├── .circleci └── config.yml ├── .gitattributes ├── .github ├── bench.py ├── bootstrap.sh ├── optimize_performance.sh └── workflows │ ├── bench_job.yml │ └── benchmarks.yml ├── .gitignore ├── .style.yapf ├── LICENSE ├── README.md ├── configure_mlx.sh ├── docs ├── exo-logo-black-bg.jpg ├── exo-logo-transparent-black-text.png ├── exo-logo-transparent.png ├── exo-rounded.png └── exo-screenshot.jpg ├── examples ├── astra │ ├── README.md │ ├── astra.xcodeproj │ │ ├── project.pbxproj │ │ └── project.xcworkspace │ │ │ ├── contents.xcworkspacedata │ │ │ └── xcshareddata │ │ │ ├── IDEWorkspaceChecks.plist │ │ │ └── swiftpm │ │ │ └── Package.resolved │ ├── astra │ │ ├── Assets.xcassets │ │ │ ├── AccentColor.colorset │ │ │ │ └── Contents.json │ │ │ ├── AppIcon.appiconset │ │ │ │ └── Contents.json │ │ │ └── Contents.json │ │ ├── ContentView.swift │ │ ├── Preview Content │ │ │ └── Preview Assets.xcassets │ │ │ │ └── Contents.json │ │ ├── astra.entitlements │ │ └── astraApp.swift │ ├── astraTests │ │ └── astraTests.swift │ └── astraUITests │ │ ├── astraUITests.swift │ │ └── astraUITestsLaunchTests.swift ├── chatgpt_api.sh └── function_calling.py ├── exo ├── __init__.py ├── api │ ├── __init__.py │ └── chatgpt_api.py ├── apputil │ ├── __init__.py │ ├── anim.py │ └── baseimages │ │ ├── image1.png │ │ ├── image2.png │ │ ├── image3.png │ │ └── image4.png ├── download │ ├── __init__.py │ ├── download_progress.py │ ├── hf │ │ ├── __init__.py │ │ └── hf_helpers.py │ ├── new_shard_download.py │ ├── shard_download.py │ └── test_new_shard_download.py ├── helpers.py ├── inference │ ├── __init__.py │ ├── debug_inference_engine.py │ ├── dummy_inference_engine.py │ ├── inference_engine.py │ ├── mlx │ │ ├── __init__.py │ │ ├── losses.py │ │ ├── models │ │ │ ├── StableDiffusionPipeline.py │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── deepseek_v2.py │ │ │ ├── deepseek_v3.py │ │ │ ├── gemma2.py │ │ │ ├── llama.py │ │ │ ├── llava.py │ │ │ ├── phi3.py │ │ │ ├── qwen2.py │ │ │ └── sd_models │ │ │ │ ├── clip.py │ │ │ │ ├── tokenizer.py │ │ │ │ ├── unet.py │ │ │ │ └── vae.py │ │ ├── perf_improvements.md │ │ ├── sharded_inference_engine.py │ │ ├── sharded_utils.py │ │ ├── test_non_blocking.py │ │ └── test_sharded_model.py │ ├── shard.py │ ├── test_dummy_inference_engine.py │ ├── test_inference_engine.py │ ├── tinygrad │ │ ├── __init__.py │ │ ├── inference.py │ │ ├── losses.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ └── llama.py │ │ ├── stateful_model.py │ │ └── tinygrad_helpers.py │ └── tokenizers.py ├── main.py ├── models.py ├── networking │ ├── __init__.py │ ├── discovery.py │ ├── grpc │ │ ├── __init__.py │ │ ├── grpc_peer_handle.py │ │ ├── grpc_server.py │ │ ├── node_service.proto │ │ ├── node_service_pb2.py │ │ └── node_service_pb2_grpc.py │ ├── manual │ │ ├── __init__.py │ │ ├── manual_discovery.py │ │ ├── network_topology_config.py │ │ ├── test_data │ │ │ ├── invalid_config.json │ │ │ ├── invalid_json.json │ │ │ ├── test_config.json │ │ │ └── test_config_single_node.json │ │ ├── test_manual_discovery.py │ │ └── test_network_topology_config.py │ ├── peer_handle.py │ ├── server.py │ ├── tailscale │ │ ├── __init__.py │ │ ├── tailscale_discovery.py │ │ ├── tailscale_helpers.py │ │ └── test_tailscale_discovery.py │ └── udp │ │ ├── __init__.py │ │ ├── test_udp_discovery.py │ │ └── udp_discovery.py ├── orchestration │ ├── __init__.py │ ├── node.py │ ├── test_node.py │ └── tracing.py ├── test_callbacks.py ├── tinychat │ ├── common.css │ ├── favicon.svg │ ├── index.css │ ├── index.html │ ├── index.js │ ├── static │ │ ├── 
cdn.jsdelivr.net │ │ │ └── npm │ │ │ │ ├── @alpine-collective │ │ │ │ └── toolkit@1.0.2 │ │ │ │ │ └── dist │ │ │ │ │ └── cdn.min.js │ │ │ │ ├── @alpinejs │ │ │ │ ├── focus@3.x.x │ │ │ │ │ └── dist │ │ │ │ │ │ └── cdn.min.js │ │ │ │ └── intersect@3.x.x │ │ │ │ │ └── dist │ │ │ │ │ └── cdn.min.js │ │ │ │ └── purecss@3.0.0 │ │ │ │ └── build │ │ │ │ └── base-min.css │ │ ├── cdnjs.cloudflare.com │ │ │ └── ajax │ │ │ │ └── libs │ │ │ │ └── font-awesome │ │ │ │ └── 6.5.2 │ │ │ │ ├── css │ │ │ │ └── all.min.css │ │ │ │ └── webfonts │ │ │ │ ├── fa-brands-400.ttf │ │ │ │ ├── fa-brands-400.woff2 │ │ │ │ ├── fa-regular-400.ttf │ │ │ │ ├── fa-regular-400.woff2 │ │ │ │ ├── fa-solid-900.ttf │ │ │ │ ├── fa-solid-900.woff2 │ │ │ │ ├── fa-v4compatibility.ttf │ │ │ │ └── fa-v4compatibility.woff2 │ │ ├── fonts.googleapis.com │ │ │ └── css2 │ │ └── unpkg.com │ │ │ ├── @highlightjs │ │ │ └── cdn-assets@11.9.0 │ │ │ │ ├── highlight.min.js │ │ │ │ └── styles │ │ │ │ └── vs2015.min.css │ │ │ ├── @marcreichel │ │ │ └── alpine-autosize@1.3.x │ │ │ │ └── dist │ │ │ │ └── alpine-autosize.min.js │ │ │ ├── alpinejs@3.x.x │ │ │ └── dist │ │ │ │ └── cdn.min.js │ │ │ ├── dompurify@3.1.5 │ │ │ └── dist │ │ │ │ └── purify.min.js │ │ │ ├── marked-highlight@2.1.2 │ │ │ └── lib │ │ │ │ └── index.umd.js │ │ │ └── marked@13.0.0 │ │ │ └── marked.min.js │ └── update_deps.py ├── topology │ ├── __init__.py │ ├── device_capabilities.py │ ├── partitioning_strategy.py │ ├── ring_memory_weighted_partitioning_strategy.py │ ├── test_device_capabilities.py │ ├── test_map_partitions.py │ ├── test_ring_memory_weighted_partitioning_strategy.py │ └── topology.py ├── train │ ├── __init__.py │ ├── data │ │ └── lora │ │ │ ├── test.jsonl │ │ │ ├── train.jsonl │ │ │ └── valid.jsonl │ └── dataset.py └── viz │ ├── __init__.py │ ├── test_topology_viz.py │ └── topology_viz.py ├── extra ├── dashboard │ ├── dashboard.py │ ├── requirements.txt │ └── sounds │ │ ├── gta5_wasted.mp3 │ │ └── pokemon_evolve.mp3 ├── line_counter.py ├── pipsize.py └── start_openwebui.sh ├── format.py ├── install.sh ├── scripts ├── build_exo.py └── compile_grpc.sh ├── setup.py └── test ├── reconnect.sh ├── test_model_helpers.py └── test_tokenizers.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.mp3 filter=lfs diff=lfs merge=lfs -text 2 | *.png filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.github/optimize_performance.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # Function to log with timestamp 5 | log() { 6 | echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1" 7 | } 8 | 9 | log "Applying comprehensive performance optimizations..." 10 | 11 | # System-wide power management 12 | log "Configuring power management..." 13 | sudo pmset -a lessbright 0 14 | sudo pmset -a disablesleep 1 15 | sudo pmset -a sleep 0 16 | sudo pmset -a hibernatemode 0 17 | sudo pmset -a autopoweroff 0 18 | sudo pmset -a standby 0 19 | sudo pmset -a powernap 0 20 | sudo pmset -a proximitywake 0 21 | sudo pmset -a tcpkeepalive 1 22 | sudo pmset -a powermode 2 23 | sudo pmset -a gpuswitch 2 24 | sudo pmset -a displaysleep 0 25 | sudo pmset -a disksleep 0 26 | 27 | # Memory and kernel optimizations 28 | log "Configuring memory and kernel settings..." 
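# Note (added comment): these sysctl writes take effect immediately but do not persist across reboots;
# they keep the kernel from purging memory on pressure warnings/criticals and disable kernel timer coalescing.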
29 | sudo sysctl -w kern.memorystatus_purge_on_warning=0 30 | sudo sysctl -w kern.memorystatus_purge_on_critical=0 31 | sudo sysctl -w kern.timer.coalescing_enabled=0 32 | 33 | # Metal and GPU optimizations 34 | log "Configuring Metal and GPU settings..." 35 | defaults write com.apple.CoreML MPSEnableGPUValidation -bool false 36 | defaults write com.apple.CoreML MPSEnableMetalValidation -bool false 37 | defaults write com.apple.CoreML MPSEnableGPUDebug -bool false 38 | defaults write com.apple.Metal GPUDebug -bool false 39 | defaults write com.apple.Metal GPUValidation -bool false 40 | defaults write com.apple.Metal MetalValidation -bool false 41 | defaults write com.apple.Metal MetalCaptureEnabled -bool false 42 | defaults write com.apple.Metal MTLValidationBehavior -string "Disabled" 43 | defaults write com.apple.Metal EnableMTLDebugLayer -bool false 44 | defaults write com.apple.Metal MTLDebugLevel -int 0 45 | defaults write com.apple.Metal PreferIntegratedGPU -bool false 46 | defaults write com.apple.Metal ForceMaximumPerformance -bool true 47 | defaults write com.apple.Metal MTLPreferredDeviceGPUFrame -bool true 48 | 49 | # Create MPS cache directory with proper permissions 50 | sudo mkdir -p /tmp/mps_cache 51 | sudo chmod 777 /tmp/mps_cache 52 | 53 | # Process and resource limits 54 | log "Configuring process limits..." 55 | sudo launchctl limit maxfiles 524288 524288 56 | ulimit -n 524288 || log "Warning: Could not set file descriptor limit" 57 | ulimit -c 0 58 | ulimit -l unlimited || log "Warning: Could not set memory lock limit" 59 | 60 | # Export performance-related environment variables 61 | cat << 'EOF' > /tmp/performance_env.sh 62 | # Metal optimizations 63 | export MTL_DEBUG_LAYER=0 64 | export METAL_DEVICE_WRAPPER_TYPE=1 65 | export METAL_DEBUG_ERROR_MODE=0 66 | export METAL_FORCE_PERFORMANCE_MODE=1 67 | export METAL_DEVICE_PRIORITY=high 68 | export METAL_MAX_COMMAND_QUEUES=1024 69 | export METAL_LOAD_LIMIT=0 70 | export METAL_VALIDATION_ENABLED=0 71 | export METAL_ENABLE_VALIDATION_LAYER=0 72 | export OBJC_DEBUG_MISSING_POOLS=NO 73 | export MPS_CACHEDIR=/tmp/mps_cache 74 | 75 | # MLX optimizations 76 | export MLX_USE_GPU=1 77 | export MLX_METAL_COMPILE_ASYNC=1 78 | export MLX_METAL_PREALLOCATE=1 79 | export MLX_METAL_MEMORY_GUARD=0 80 | export MLX_METAL_CACHE_KERNELS=1 81 | export MLX_PLACEMENT_POLICY=metal 82 | export MLX_METAL_VALIDATION=0 83 | export MLX_METAL_DEBUG=0 84 | export MLX_FORCE_P_CORES=1 85 | export MLX_METAL_MEMORY_BUDGET=0 86 | export MLX_METAL_PREWARM=1 87 | 88 | # Python optimizations 89 | export PYTHONUNBUFFERED=1 90 | export PYTHONOPTIMIZE=2 91 | export PYTHONHASHSEED=0 92 | export PYTHONDONTWRITEBYTECODE=1 93 | EOF 94 | 95 | log "Performance optimizations completed. 
Environment variables written to /tmp/performance_env.sh" -------------------------------------------------------------------------------- /.github/workflows/benchmarks.yml: -------------------------------------------------------------------------------- 1 | name: Build and Test 2 | 3 | on: 4 | push: 5 | branches: [ '*' ] 6 | tags: [ '*' ] 7 | pull_request: 8 | branches: [ '*' ] 9 | 10 | jobs: 11 | single-m4-pro: 12 | strategy: 13 | matrix: 14 | model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b'] 15 | uses: ./.github/workflows/bench_job.yml 16 | with: 17 | config: '{"M4PRO_GPU16_24GB": 1}' 18 | model: ${{ matrix.model }} 19 | calling_job_name: 'single-m4-pro' 20 | network_interface: 'Ethernet' 21 | secrets: inherit 22 | 23 | two-m4-pro-cluster: 24 | strategy: 25 | matrix: 26 | model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b'] 27 | uses: ./.github/workflows/bench_job.yml 28 | with: 29 | config: '{"M4PRO_GPU16_24GB": 2}' 30 | model: ${{ matrix.model }} 31 | calling_job_name: 'two-m4-pro-cluster' 32 | network_interface: 'Ethernet' 33 | secrets: inherit 34 | 35 | # two-m4-pro-cluster-thunderbolt: 36 | # strategy: 37 | # matrix: 38 | # model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b'] 39 | # uses: ./.github/workflows/bench_job.yml 40 | # with: 41 | # config: '{"M4PRO_GPU16_24GB": 2}' 42 | # model: ${{ matrix.model }} 43 | # calling_job_name: 'two-m4-pro-cluster-thunderbolt' 44 | # network_interface: 'Thunderbolt' 45 | # secrets: inherit 46 | 47 | three-m4-pro-cluster: 48 | strategy: 49 | matrix: 50 | model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b', 'llama-3.3-70b'] 51 | fail-fast: false 52 | uses: ./.github/workflows/bench_job.yml 53 | with: 54 | config: '{"M4PRO_GPU16_24GB": 3}' 55 | model: ${{ matrix.model }} 56 | calling_job_name: 'three-m4-pro-cluster' 57 | network_interface: 'Ethernet' 58 | secrets: inherit 59 | 60 | # test-m3-single-node: 61 | # strategy: 62 | # matrix: 63 | # model: ['llama-3.2-1b'] 64 | # fail-fast: false 65 | # uses: ./.github/workflows/bench_job.yml 66 | # with: 67 | # config: '{"M3MAX_GPU40_128GB": 1}' 68 | # model: ${{ matrix.model }} 69 | # calling_job_name: 'test-m3-cluster' 70 | # network_interface: 'Ethernet' 71 | # secrets: inherit -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .venv* 3 | test_weights.npz 4 | .exo_used_ports 5 | .exo_node_id 6 | .idea 7 | .DS_Store 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | /.Python 19 | /develop-eggs/ 20 | /dist/ 21 | /downloads/ 22 | /eggs/ 23 | /.eggs/ 24 | /lib/ 25 | /lib64/ 26 | /parts/ 27 | /sdist/ 28 | /var/ 29 | /wheels/ 30 | /share/python-wheels/ 31 | /*.egg-info/ 32 | /.installed.cfg 33 | /*.egg 34 | /MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | Untitled.ipynb 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | # For a library or package, you might want to ignore these files since the code is 95 | # intended to run in multiple environments; otherwise, check them in: 96 | # .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # poetry 106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 110 | #poetry.lock 111 | 112 | # pdm 113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 114 | #pdm.lock 115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 116 | # in version control. 117 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 118 | .pdm.toml 119 | .pdm-python 120 | .pdm-build/ 121 | 122 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 123 | __pypackages__/ 124 | 125 | # Celery stuff 126 | celerybeat-schedule 127 | celerybeat.pid 128 | 129 | # SageMath parsed files 130 | *.sage.py 131 | 132 | # Environments 133 | .env 134 | .venv 135 | env/ 136 | venv/ 137 | ENV/ 138 | env.bak/ 139 | venv.bak/ 140 | 141 | # Spyder project settings 142 | .spyderproject 143 | .spyproject 144 | 145 | # Rope project settings 146 | .ropeproject 147 | 148 | # mkdocs documentation 149 | /site 150 | 151 | # mypy 152 | .mypy_cache/ 153 | .dmypy.json 154 | dmypy.json 155 | 156 | # Pyre type checker 157 | .pyre/ 158 | 159 | # pytype static type analyzer 160 | .pytype/ 161 | 162 | # Cython debug symbols 163 | cython_debug/ 164 | 165 | # PyCharm 166 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 167 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 168 | # and can be added to the global gitignore or merged into this file. For a more nuclear 169 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
170 | #.idea/ 171 | 172 | **/*.xcodeproj/* 173 | .aider* 174 | 175 | exo/tinychat/images/*.png 176 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = pep8 3 | indent_width = 2 4 | column_limit = 200 5 | allow_split_before_dict_value = False 6 | dedent_closing_brackets = True 7 | split_before_first_argument = False 8 | split_complex_comprehension = False 9 | continuation_indent_width = 2 10 | indent_dictionary_value = True 11 | allow_multiline_dictionary_keys = True 12 | each_dict_entry_on_separate_line = False 13 | allow_multiline_lambdas = True 14 | blank_line_before_nested_class_or_def = False 15 | arithmetic_precedence_indication = True 16 | no_spaces_around_selected_binary_operators = "*,/" 17 | coalesce_brackets = True 18 | space_between_ending_comma_and_closing_bracket = False 19 | split_before_expression_after_opening_paren = False -------------------------------------------------------------------------------- /configure_mlx.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Get the total memory in MB 4 | TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024)) 5 | 6 | # Calculate 80% and TOTAL_MEM_GB-5GB in MB 7 | EIGHTY_PERCENT=$(($TOTAL_MEM_MB * 80 / 100)) 8 | MINUS_5GB=$((($TOTAL_MEM_MB - 5120))) 9 | 10 | # Calculate 70% and TOTAL_MEM_GB-8GB in MB 11 | SEVENTY_PERCENT=$(($TOTAL_MEM_MB * 70 / 100)) 12 | MINUS_8GB=$((($TOTAL_MEM_MB - 8192))) 13 | 14 | # Set WIRED_LIMIT_MB to higher value 15 | if [ $EIGHTY_PERCENT -gt $MINUS_5GB ]; then 16 | WIRED_LIMIT_MB=$EIGHTY_PERCENT 17 | else 18 | WIRED_LIMIT_MB=$MINUS_5GB 19 | fi 20 | 21 | # Set WIRED_LWM_MB to higher value 22 | if [ $SEVENTY_PERCENT -gt $MINUS_8GB ]; then 23 | WIRED_LWM_MB=$SEVENTY_PERCENT 24 | else 25 | WIRED_LWM_MB=$MINUS_8GB 26 | fi 27 | 28 | # Display the calculated values 29 | echo "Total memory: $TOTAL_MEM_MB MB" 30 | echo "Maximum limit (iogpu.wired_limit_mb): $WIRED_LIMIT_MB MB" 31 | echo "Lower bound (iogpu.wired_lwm_mb): $WIRED_LWM_MB MB" 32 | 33 | # Apply the values with sysctl, but check if we're already root 34 | if [ "$EUID" -eq 0 ]; then 35 | sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB 36 | sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 37 | else 38 | # Try without sudo first, fall back to sudo if needed 39 | sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB 2>/dev/null || \ 40 | sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB 41 | sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 2>/dev/null || \ 42 | sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 43 | fi -------------------------------------------------------------------------------- /docs/exo-logo-black-bg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/docs/exo-logo-black-bg.jpg -------------------------------------------------------------------------------- /docs/exo-logo-transparent-black-text.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:1c6f0b66b68ffc11a42cf25fbd43a6fbea99869ed4ba82e5f480d8213e9b7061 3 | size 1296 4 | -------------------------------------------------------------------------------- /docs/exo-logo-transparent.png: 
-------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c7aeca6a876a195df706f3221f1bfd4792884e6042c2b355026f94cba0f7576d 3 | size 1296 4 | -------------------------------------------------------------------------------- /docs/exo-rounded.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:1316a53899f32ba6c33b083fca232b638aea4efbcf36bc99e640369169e6a1c9 3 | size 28651 4 | -------------------------------------------------------------------------------- /docs/exo-screenshot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/docs/exo-screenshot.jpg -------------------------------------------------------------------------------- /examples/astra/README.md: -------------------------------------------------------------------------------- 1 | ## Overview 2 | 3 | This example app is an open-source alternative to [Google's Project Astra](https://deepmind.google/technologies/gemini/project-astra/). It leverages the exo library to run on your own devices, providing a fully transparent and customizable experience compared to Google's closed-source API. 4 | -------------------------------------------------------------------------------- /examples/astra/astra.xcodeproj/project.xcworkspace/contents.xcworkspacedata: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /examples/astra/astra.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | IDEDidComputeMac32BitWarning 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /examples/astra/astra.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved: -------------------------------------------------------------------------------- 1 | { 2 | "originHash" : "8f61689e55c5551e76f2c686d145061dc1fa621a58cbca576565ebfabc15c894", 3 | "pins" : [ 4 | { 5 | "identity" : "swift-argument-parser", 6 | "kind" : "remoteSourceControl", 7 | "location" : "https://github.com/apple/swift-argument-parser.git", 8 | "state" : { 9 | "revision" : "c8ed701b513cf5177118a175d85fbbbcd707ab41", 10 | "version" : "1.3.0" 11 | } 12 | }, 13 | { 14 | "identity" : "swift-transformers", 15 | "kind" : "remoteSourceControl", 16 | "location" : "https://github.com/huggingface/swift-transformers.git", 17 | "state" : { 18 | "revision" : "74b94211bdc741694ed7e700a1104c72e5ba68fe", 19 | "version" : "0.1.7" 20 | } 21 | }, 22 | { 23 | "identity" : "whisperkit", 24 | "kind" : "remoteSourceControl", 25 | "location" : "https://github.com/argmaxinc/whisperkit", 26 | "state" : { 27 | "branch" : "main", 28 | "revision" : "59aaa4e5f211622f9a5e133440220d9974641d3b" 29 | } 30 | } 31 | ], 32 | "version" : 3 33 | } 34 | -------------------------------------------------------------------------------- /examples/astra/astra/Assets.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "idiom" : "universal" 5 | } 6 | ], 7 | "info" : { 8 | "author" : "xcode", 9 | "version" : 1 10 | } 11 | } 12 | 
-------------------------------------------------------------------------------- /examples/astra/astra/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "idiom" : "universal", 5 | "platform" : "ios", 6 | "size" : "1024x1024" 7 | }, 8 | { 9 | "idiom" : "mac", 10 | "scale" : "1x", 11 | "size" : "16x16" 12 | }, 13 | { 14 | "idiom" : "mac", 15 | "scale" : "2x", 16 | "size" : "16x16" 17 | }, 18 | { 19 | "idiom" : "mac", 20 | "scale" : "1x", 21 | "size" : "32x32" 22 | }, 23 | { 24 | "idiom" : "mac", 25 | "scale" : "2x", 26 | "size" : "32x32" 27 | }, 28 | { 29 | "idiom" : "mac", 30 | "scale" : "1x", 31 | "size" : "128x128" 32 | }, 33 | { 34 | "idiom" : "mac", 35 | "scale" : "2x", 36 | "size" : "128x128" 37 | }, 38 | { 39 | "idiom" : "mac", 40 | "scale" : "1x", 41 | "size" : "256x256" 42 | }, 43 | { 44 | "idiom" : "mac", 45 | "scale" : "2x", 46 | "size" : "256x256" 47 | }, 48 | { 49 | "idiom" : "mac", 50 | "scale" : "1x", 51 | "size" : "512x512" 52 | }, 53 | { 54 | "idiom" : "mac", 55 | "scale" : "2x", 56 | "size" : "512x512" 57 | } 58 | ], 59 | "info" : { 60 | "author" : "xcode", 61 | "version" : 1 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /examples/astra/astra/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /examples/astra/astra/Preview Content/Preview Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /examples/astra/astra/astra.entitlements: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | com.apple.developer.kernel.increased-memory-limit 6 | 7 | com.apple.security.app-sandbox 8 | 9 | com.apple.security.device.audio-input 10 | 11 | com.apple.security.files.downloads.read-only 12 | 13 | com.apple.security.files.user-selected.read-write 14 | 15 | com.apple.security.network.client 16 | 17 | com.apple.security.network.server 18 | 19 | com.apple.security.device.camera 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /examples/astra/astra/astraApp.swift: -------------------------------------------------------------------------------- 1 | // 2 | // astraApp.swift 3 | // astra 4 | // 5 | // Created by Alex on 18/08/2024. 6 | // 7 | 8 | import SwiftUI 9 | 10 | @main 11 | struct astraApp: App { 12 | var body: some Scene { 13 | WindowGroup { 14 | ContentView() 15 | } 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /examples/astra/astraTests/astraTests.swift: -------------------------------------------------------------------------------- 1 | // 2 | // astraTests.swift 3 | // astraTests 4 | // 5 | // Created by Alex on 18/08/2024. 6 | // 7 | 8 | import XCTest 9 | 10 | final class astraTests: XCTestCase { 11 | 12 | override func setUpWithError() throws { 13 | // Put setup code here. This method is called before the invocation of each test method in the class. 
14 | } 15 | 16 | override func tearDownWithError() throws { 17 | // Put teardown code here. This method is called after the invocation of each test method in the class. 18 | } 19 | 20 | func testExample() throws { 21 | // This is an example of a functional test case. 22 | // Use XCTAssert and related functions to verify your tests produce the correct results. 23 | // Any test you write for XCTest can be annotated as throws and async. 24 | // Mark your test throws to produce an unexpected failure when your test encounters an uncaught error. 25 | // Mark your test async to allow awaiting for asynchronous code to complete. Check the results with assertions afterwards. 26 | } 27 | 28 | func testPerformanceExample() throws { 29 | // This is an example of a performance test case. 30 | measure { 31 | // Put the code you want to measure the time of here. 32 | } 33 | } 34 | 35 | } 36 | -------------------------------------------------------------------------------- /examples/astra/astraUITests/astraUITests.swift: -------------------------------------------------------------------------------- 1 | // 2 | // astraUITests.swift 3 | // astraUITests 4 | // 5 | // Created by Alex on 18/08/2024. 6 | // 7 | 8 | import XCTest 9 | 10 | final class astraUITests: XCTestCase { 11 | 12 | override func setUpWithError() throws { 13 | // Put setup code here. This method is called before the invocation of each test method in the class. 14 | 15 | // In UI tests it is usually best to stop immediately when a failure occurs. 16 | continueAfterFailure = false 17 | 18 | // In UI tests it’s important to set the initial state - such as interface orientation - required for your tests before they run. The setUp method is a good place to do this. 19 | } 20 | 21 | override func tearDownWithError() throws { 22 | // Put teardown code here. This method is called after the invocation of each test method in the class. 23 | } 24 | 25 | func testExample() throws { 26 | // UI tests must launch the application that they test. 27 | let app = XCUIApplication() 28 | app.launch() 29 | 30 | // Use XCTAssert and related functions to verify your tests produce the correct results. 31 | } 32 | 33 | func testLaunchPerformance() throws { 34 | if #available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 7.0, *) { 35 | // This measures how long it takes to launch your application. 36 | measure(metrics: [XCTApplicationLaunchMetric()]) { 37 | XCUIApplication().launch() 38 | } 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /examples/astra/astraUITests/astraUITestsLaunchTests.swift: -------------------------------------------------------------------------------- 1 | // 2 | // astraUITestsLaunchTests.swift 3 | // astraUITests 4 | // 5 | // Created by Alex on 18/08/2024. 
6 | // 7 | 8 | import XCTest 9 | 10 | final class astraUITestsLaunchTests: XCTestCase { 11 | 12 | override class var runsForEachTargetApplicationUIConfiguration: Bool { 13 | true 14 | } 15 | 16 | override func setUpWithError() throws { 17 | continueAfterFailure = false 18 | } 19 | 20 | func testLaunch() throws { 21 | let app = XCUIApplication() 22 | app.launch() 23 | 24 | // Insert steps here to perform after app launch but before taking a screenshot, 25 | // such as logging into a test account or navigating somewhere in the app 26 | 27 | let attachment = XCTAttachment(screenshot: app.screenshot()) 28 | attachment.name = "Launch Screen" 29 | attachment.lifetime = .keepAlways 30 | add(attachment) 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /examples/chatgpt_api.sh: -------------------------------------------------------------------------------- 1 | # exo provides an API that aims to be a drop-in replacements for the ChatGPT-API. 2 | # This example shows how you can use the API first without streaming and second with streaming. 3 | # This works the same in a single-node set up and in a multi-node setup. 4 | # You need to start exo before running this by running `python3 main.py`. 5 | 6 | API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):52415}" 7 | MODEL="llama-3.1-8b" 8 | PROMPT="What is the meaning of exo?" 9 | TEMPERATURE=0.7 10 | 11 | echo "" 12 | echo "" 13 | echo "--- Output without streaming:" 14 | echo "" 15 | curl "${API_ENDPOINT}/v1/chat/completions" --silent \ 16 | -H "Content-Type: application/json" \ 17 | -d '{ 18 | "model": "'"${MODEL}"'", 19 | "messages": [{"role": "user", "content": "'"${PROMPT}"'"}], 20 | "temperature": '"${TEMPERATURE}"' 21 | }' 22 | 23 | echo "" 24 | echo "" 25 | echo "--- Output with streaming:" 26 | echo "" 27 | curl "${API_ENDPOINT}/v1/chat/completions" --silent \ 28 | -H "Content-Type: application/json" \ 29 | -d '{ 30 | "model": "'"${MODEL}"'", 31 | "messages": [{"role": "user", "content": "'"${PROMPT}"'"}], 32 | "temperature": '"${TEMPERATURE}"', 33 | "stream": true 34 | }' | while read -r line; do 35 | if [[ $line == data:* ]]; then 36 | content=$(echo "$line" | sed 's/^data: //') 37 | echo "$content" | jq -r '.choices[].delta.content' --unbuffered | tr -d '\n' 38 | fi 39 | done -------------------------------------------------------------------------------- /examples/function_calling.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import requests 4 | 5 | def get_current_weather(location: str, unit: str = "celsius"): 6 | """Mock weather data function""" 7 | # Hardcoded response for demo purposes 8 | return { 9 | "location": location, 10 | "temperature": 22 if unit == "celsius" else 72, 11 | "unit": unit, 12 | "forecast": "Sunny with light clouds" 13 | } 14 | 15 | def try_parse_tool_calls(content: str): 16 | """Try parse the tool calls.""" 17 | tool_calls = [] 18 | offset = 0 19 | for i, m in enumerate(re.finditer(r"\n(.+)?\n", content)): 20 | if i == 0: 21 | offset = m.start() 22 | try: 23 | func = json.loads(m.group(1)) 24 | tool_calls.append({"type": "function", "function": func}) 25 | if isinstance(func["arguments"], str): 26 | func["arguments"] = json.loads(func["arguments"]) 27 | except json.JSONDecodeError as e: 28 | print(f"Failed to parse tool calls: the content is {m.group(1)} and {e}") 29 | pass 30 | if tool_calls: 31 | if offset > 0 and 
content[:offset].strip(): 32 | c = content[:offset] 33 | else: 34 | c = "" 35 | return {"role": "assistant", "content": c, "tool_calls": tool_calls} 36 | return {"role": "assistant", "content": re.sub(r"<\|im_end\|>$", "", content)} 37 | 38 | def chat_completion(messages): 39 | """Send chat completion request to local server""" 40 | response = requests.post( 41 | "http://localhost:52415/v1/chat/completions", 42 | json={ 43 | "model": "qwen-2.5-1.5b", 44 | "messages": messages, 45 | "tools": [{ 46 | "type": "function", 47 | "function": { 48 | "name": "get_current_weather", 49 | "description": "Get the current weather in a given location", 50 | "parameters": { 51 | "type": "object", 52 | "properties": { 53 | "location": { 54 | "type": "string", 55 | "description": "The city and state, e.g. San Francisco, CA" 56 | }, 57 | "unit": { 58 | "type": "string", 59 | "enum": ["celsius", "fahrenheit"] 60 | } 61 | }, 62 | "required": ["location"] 63 | } 64 | } 65 | }], 66 | "tool_choice": "auto" 67 | } 68 | ) 69 | return response.json() 70 | 71 | def main(): 72 | # Initial conversation 73 | messages = [{ 74 | "role": "user", 75 | "content": "Hi there, what's the weather in Boston?" 76 | }] 77 | 78 | # Get initial response 79 | response = chat_completion(messages) 80 | print(f"First response: {response}") 81 | assistant_message = try_parse_tool_calls(response["choices"][0]["message"]["content"]) 82 | messages.append(assistant_message) 83 | 84 | # If there are tool calls, execute them and continue conversation 85 | if "tool_calls" in assistant_message: 86 | for tool_call in assistant_message["tool_calls"]: 87 | if tool_call["function"]["name"] == "get_current_weather": 88 | args = tool_call["function"]["arguments"] 89 | weather_data = get_current_weather(**args) 90 | 91 | # Add tool response to messages 92 | messages.append({ 93 | "role": "tool", 94 | "content": json.dumps(weather_data), 95 | "name": tool_call["function"]["name"] 96 | }) 97 | 98 | # Get final response with weather data 99 | response = chat_completion(messages) 100 | print(f"Final response: {response}") 101 | messages.append({ 102 | "role": "assistant", 103 | "content": response["choices"][0]["message"]["content"] 104 | }) 105 | 106 | # Print full conversation 107 | for msg in messages: 108 | print(f"\n{msg['role'].upper()}: {msg['content']}") 109 | 110 | if __name__ == "__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /exo/__init__.py: -------------------------------------------------------------------------------- 1 | from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION -------------------------------------------------------------------------------- /exo/api/__init__.py: -------------------------------------------------------------------------------- 1 | from exo.api.chatgpt_api import ChatGPTAPI as ChatGPTAPI 2 | -------------------------------------------------------------------------------- /exo/apputil/__init__.py: -------------------------------------------------------------------------------- 1 | from exo.apputil.anim import create_animation_mp4 -------------------------------------------------------------------------------- /exo/apputil/anim.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw, ImageFont, ImageFilter 2 | import os 3 | import numpy as np 4 | import cv2 5 | import sys 6 | 7 | def draw_rounded_rectangle(draw, coords, radius, fill): 8 | left, top, 
right, bottom = coords 9 | diameter = radius * 2 10 | draw.rectangle([left + radius, top, right - radius, bottom], fill=fill) 11 | draw.rectangle([left, top + radius, right, bottom - radius], fill=fill) 12 | draw.pieslice([left, top, left + diameter, top + diameter], 180, 270, fill=fill) 13 | draw.pieslice([right - diameter, top, right, top + diameter], 270, 360, fill=fill) 14 | draw.pieslice([left, bottom - diameter, left + diameter, bottom], 90, 180, fill=fill) 15 | draw.pieslice([right - diameter, bottom - diameter, right, bottom], 0, 90, fill=fill) 16 | 17 | def draw_centered_text_rounded(draw, text, font, rect_coords, radius=10, text_color="yellow", bg_color=(43,33,44)): 18 | bbox = font.getbbox(text) 19 | text_width = bbox[2] - bbox[0] 20 | text_height = bbox[3] - bbox[1] 21 | rect_left, rect_top, rect_right, rect_bottom = rect_coords 22 | rect_width = rect_right - rect_left 23 | rect_height = rect_bottom - rect_top 24 | text_x = rect_left + (rect_width - text_width) // 2 25 | text_y = rect_top + (rect_height - text_height) // 2 26 | draw_rounded_rectangle(draw, rect_coords, radius, bg_color) 27 | draw.text((text_x, text_y), text, fill=text_color, font=font) 28 | 29 | def draw_left_aligned_text_rounded(draw, text, font, rect_coords, padding_left=20, radius=10, text_color="yellow", bg_color=(43,33,44)): 30 | bbox = font.getbbox(text) 31 | text_height = bbox[3] - bbox[1] 32 | rect_left, rect_top, rect_right, rect_bottom = rect_coords 33 | rect_height = rect_bottom - rect_top 34 | text_y = rect_top + (rect_height - text_height) // 2 35 | text_x = rect_left + padding_left 36 | draw_rounded_rectangle(draw, rect_coords, radius, bg_color) 37 | draw.text((text_x, text_y), text, fill=text_color, font=font) 38 | 39 | def draw_right_text_dynamic_width_rounded(draw, text, font, base_coords, padding=20, radius=10, text_color="yellow", bg_color=(43,33,44)): 40 | bbox = font.getbbox(text) 41 | text_width = bbox[2] - bbox[0] 42 | text_height = bbox[3] - bbox[1] 43 | _, rect_top, rect_right, rect_bottom = base_coords 44 | rect_height = rect_bottom - rect_top 45 | new_rect_left = rect_right - (text_width + (padding * 2)) 46 | text_y = rect_top + (rect_height - text_height) // 2 47 | text_x = new_rect_left + padding 48 | draw_rounded_rectangle(draw, (new_rect_left, rect_top, rect_right, rect_bottom), radius, bg_color) 49 | draw.text((text_x, text_y), text, fill=text_color, font=font) 50 | return new_rect_left 51 | 52 | def draw_progress_bar(draw, progress, coords, color="yellow", bg_color=(70, 70, 70)): 53 | left, top, right, bottom = coords 54 | total_width = right - left 55 | draw.rectangle(coords, fill=bg_color) 56 | progress_width = int(total_width * progress) 57 | if progress_width > 0: 58 | draw.rectangle((left, top, left + progress_width, bottom), fill=color) 59 | 60 | def crop_image(image, top_crop=70): 61 | width, height = image.size 62 | return image.crop((0, top_crop, width, height)) 63 | 64 | def create_animation_mp4( 65 | replacement_image_path, 66 | output_path, 67 | device_name, 68 | prompt_text, 69 | fps=30, 70 | target_size=(512, 512), 71 | target_position=(139, 755), 72 | progress_coords=(139, 1285, 655, 1295), 73 | device_coords=(1240, 370, 1640, 416), 74 | prompt_coords=(332, 1702, 2662, 1745) 75 | ): 76 | frames = [] 77 | try: 78 | font = ImageFont.truetype("/System/Library/Fonts/SFNSMono.ttf", 20) 79 | promptfont = ImageFont.truetype("/System/Library/Fonts/SFNSMono.ttf", 24) 80 | except: 81 | font = ImageFont.load_default() 82 | promptfont = ImageFont.load_default() 83 | 84 | # 
Get the base directory for images when running as a bundled app 85 | if hasattr(sys, '_MEIPASS'): 86 | base_dir = os.path.join(sys._MEIPASS, "exo", "apputil", "baseimages") 87 | else: 88 | base_dir = os.path.join(os.path.dirname(__file__), "baseimages") 89 | 90 | # Process first frame 91 | base_img = Image.open(os.path.join(base_dir, "image1.png")) 92 | draw = ImageDraw.Draw(base_img) 93 | draw_centered_text_rounded(draw, device_name, font, device_coords) 94 | frames.extend([crop_image(base_img)] * 30) # 1 second at 30fps 95 | 96 | # Process second frame with typing animation 97 | base_img2 = Image.open(os.path.join(base_dir, "image2.png")) 98 | for i in range(len(prompt_text) + 1): 99 | current_frame = base_img2.copy() 100 | draw = ImageDraw.Draw(current_frame) 101 | draw_centered_text_rounded(draw, device_name, font, device_coords) 102 | if i > 0: # Only draw if we have at least one character 103 | draw_left_aligned_text_rounded(draw, prompt_text[:i], promptfont, prompt_coords) 104 | frames.extend([crop_image(current_frame)] * 2) # 2 frames per character for smooth typing 105 | 106 | # Hold the complete prompt for a moment 107 | frames.extend([frames[-1]] * 30) # Hold for 1 second 108 | 109 | # Create blur sequence 110 | replacement_img = Image.open(replacement_image_path) 111 | base_img = Image.open(os.path.join(base_dir, "image3.png")) 112 | blur_steps = [int(80 * (1 - i/8)) for i in range(9)] 113 | 114 | for i, blur_amount in enumerate(blur_steps): 115 | new_frame = base_img.copy() 116 | draw = ImageDraw.Draw(new_frame) 117 | 118 | replacement_copy = replacement_img.copy() 119 | replacement_copy.thumbnail(target_size, Image.Resampling.LANCZOS) 120 | if blur_amount > 0: 121 | replacement_copy = replacement_copy.filter(ImageFilter.GaussianBlur(radius=blur_amount)) 122 | 123 | mask = replacement_copy.split()[-1] if replacement_copy.mode in ('RGBA', 'LA') else None 124 | new_frame.paste(replacement_copy, target_position, mask) 125 | 126 | draw_progress_bar(draw, (i + 1) / 9, progress_coords) 127 | draw_centered_text_rounded(draw, device_name, font, device_coords) 128 | draw_right_text_dynamic_width_rounded(draw, prompt_text, promptfont, (None, 590, 2850, 685), padding=30) 129 | 130 | frames.extend([crop_image(new_frame)] * 15) # 0.5 seconds at 30fps 131 | 132 | # Create and add final frame (image4) 133 | final_base = Image.open(os.path.join(base_dir, "image4.png")) 134 | draw = ImageDraw.Draw(final_base) 135 | 136 | draw_centered_text_rounded(draw, device_name, font, device_coords) 137 | draw_right_text_dynamic_width_rounded(draw, prompt_text, promptfont, (None, 590, 2850, 685), padding=30) 138 | 139 | replacement_copy = replacement_img.copy() 140 | replacement_copy.thumbnail(target_size, Image.Resampling.LANCZOS) 141 | mask = replacement_copy.split()[-1] if replacement_copy.mode in ('RGBA', 'LA') else None 142 | final_base.paste(replacement_copy, target_position, mask) 143 | 144 | frames.extend([crop_image(final_base)] * 30) # 1 second at 30fps 145 | 146 | # Convert frames to video using H.264 codec 147 | if frames: 148 | first_frame = np.array(frames[0]) 149 | height, width = first_frame.shape[:2] 150 | fourcc = cv2.VideoWriter_fourcc(*'avc1') 151 | out = cv2.VideoWriter( 152 | output_path, 153 | fourcc, 154 | fps, 155 | (width, height), 156 | isColor=True 157 | ) 158 | 159 | if not out.isOpened(): 160 | print("Error: VideoWriter failed to open") 161 | return 162 | 163 | for frame in frames: 164 | frame_array = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR) 165 | 
out.write(frame_array) 166 | 167 | out.release() 168 | print(f"Video saved successfully to {output_path}") 169 | -------------------------------------------------------------------------------- /exo/apputil/baseimages/image1.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:361fdadd67c277d45cd18b0bfc8c5ceea5fd89f2d65aef157fd915ce9cbb8599 3 | size 814460 4 | -------------------------------------------------------------------------------- /exo/apputil/baseimages/image2.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:f0e3891bc6b4f4dfa7444af53fcaa4b3ba06b0549546202be3243f08a0e6bd7e 3 | size 814235 4 | -------------------------------------------------------------------------------- /exo/apputil/baseimages/image3.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a2dc5b3378aef397d60fd1252da8a1c578ad97e202a859590ffa416b49551d19 3 | size 146633 4 | -------------------------------------------------------------------------------- /exo/apputil/baseimages/image4.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:dbc6883e2a3c5233ec7b844c98646922bdc4f5e42e1f424857eaff56f785dbcd 3 | size 668550 4 | -------------------------------------------------------------------------------- /exo/download/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/download/__init__.py -------------------------------------------------------------------------------- /exo/download/download_progress.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Callable, Coroutine, Any, Literal 2 | from exo.inference.shard import Shard 3 | from dataclasses import dataclass 4 | from datetime import timedelta 5 | 6 | 7 | @dataclass 8 | class RepoFileProgressEvent: 9 | repo_id: str 10 | repo_revision: str 11 | file_path: str 12 | downloaded: int 13 | downloaded_this_session: int 14 | total: int 15 | speed: int 16 | eta: timedelta 17 | status: Literal["not_started", "in_progress", "complete"] 18 | start_time: float 19 | 20 | def to_dict(self): 21 | return { 22 | "repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session, 23 | "total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status, "start_time": self.start_time 24 | } 25 | 26 | @classmethod 27 | def from_dict(cls, data): 28 | if 'eta' in data: data['eta'] = timedelta(seconds=data['eta']) 29 | return cls(**data) 30 | 31 | 32 | @dataclass 33 | class RepoProgressEvent: 34 | shard: Shard 35 | repo_id: str 36 | repo_revision: str 37 | completed_files: int 38 | total_files: int 39 | downloaded_bytes: int 40 | downloaded_bytes_this_session: int 41 | total_bytes: int 42 | overall_speed: int 43 | overall_eta: timedelta 44 | file_progress: Dict[str, RepoFileProgressEvent] 45 | status: Literal["not_started", "in_progress", "complete"] 46 | 47 | def to_dict(self): 48 | return { 49 | "shard": self.shard.to_dict(), "repo_id": self.repo_id, "repo_revision": 
self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes, 50 | "downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(), 51 | "file_progress": {k: v.to_dict() 52 | for k, v in self.file_progress.items()}, "status": self.status 53 | } 54 | 55 | @classmethod 56 | def from_dict(cls, data): 57 | if 'overall_eta' in data: data['overall_eta'] = timedelta(seconds=data['overall_eta']) 58 | if 'file_progress' in data: data['file_progress'] = {k: RepoFileProgressEvent.from_dict(v) for k, v in data['file_progress'].items()} 59 | if 'shard' in data: data['shard'] = Shard.from_dict(data['shard']) 60 | 61 | return cls(**data) 62 | 63 | 64 | RepoFileProgressCallback = Callable[[RepoFileProgressEvent], Coroutine[Any, Any, None]] 65 | RepoProgressCallback = Callable[[RepoProgressEvent], Coroutine[Any, Any, None]] 66 | -------------------------------------------------------------------------------- /exo/download/hf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/download/hf/__init__.py -------------------------------------------------------------------------------- /exo/download/hf/hf_helpers.py: -------------------------------------------------------------------------------- 1 | import aiofiles.os as aios 2 | from typing import Union 3 | import os 4 | from typing import Callable, Optional, Dict, List, Union 5 | from fnmatch import fnmatch 6 | from pathlib import Path 7 | from typing import Generator, Iterable, TypeVar 8 | from exo.helpers import DEBUG 9 | from exo.inference.shard import Shard 10 | import aiofiles 11 | 12 | T = TypeVar("T") 13 | 14 | def filter_repo_objects( 15 | items: Iterable[T], 16 | *, 17 | allow_patterns: Optional[Union[List[str], str]] = None, 18 | ignore_patterns: Optional[Union[List[str], str]] = None, 19 | key: Optional[Callable[[T], str]] = None, 20 | ) -> Generator[T, None, None]: 21 | if isinstance(allow_patterns, str): 22 | allow_patterns = [allow_patterns] 23 | if isinstance(ignore_patterns, str): 24 | ignore_patterns = [ignore_patterns] 25 | if allow_patterns is not None: 26 | allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns] 27 | if ignore_patterns is not None: 28 | ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns] 29 | 30 | if key is None: 31 | def _identity(item: T) -> str: 32 | if isinstance(item, str): 33 | return item 34 | if isinstance(item, Path): 35 | return str(item) 36 | raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.") 37 | key = _identity 38 | 39 | for item in items: 40 | path = key(item) 41 | if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns): 42 | continue 43 | if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns): 44 | continue 45 | yield item 46 | 47 | def _add_wildcard_to_directories(pattern: str) -> str: 48 | if pattern[-1] == "/": 49 | return pattern + "*" 50 | return pattern 51 | 52 | def get_hf_endpoint() -> str: 53 | return os.environ.get('HF_ENDPOINT', "https://huggingface.co") 54 | 55 | def get_hf_home() -> Path: 56 | """Get the Hugging Face home directory.""" 57 | return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface")) 58 | 
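# Illustrative usage of filter_repo_objects above (file names here are hypothetical, not from the repo):
#   files = ["model-00001-of-00002.safetensors", "optimizer.pt", "config.json"]
#   kept = list(filter_repo_objects(files, allow_patterns=["*.safetensors", "*.json"],
#                                   ignore_patterns=["optimizer*"]))
#   # -> ["model-00001-of-00002.safetensors", "config.json"]
#
# The helpers below read the Hugging Face access token from $HF_HOME/token (the file typically
# written by `huggingface-cli login`) and turn it into an Authorization header so that gated or
# private repos can be downloaded.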
59 | async def get_hf_token(): 60 | """Retrieve the Hugging Face token from the user's HF_HOME directory.""" 61 | token_path = get_hf_home()/"token" 62 | if await aios.path.exists(token_path): 63 | async with aiofiles.open(token_path, 'r') as f: 64 | return (await f.read()).strip() 65 | return None 66 | 67 | async def get_auth_headers(): 68 | """Get authentication headers if a token is available.""" 69 | token = await get_hf_token() 70 | if token: 71 | return {"Authorization": f"Bearer {token}"} 72 | return {} 73 | 74 | def extract_layer_num(tensor_name: str) -> Optional[int]: 75 | # This is a simple example and might need to be adjusted based on the actual naming convention 76 | parts = tensor_name.split('.') 77 | for part in parts: 78 | if part.isdigit(): 79 | return int(part) 80 | return None 81 | 82 | def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]: 83 | default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"]) 84 | shard_specific_patterns = set() 85 | if weight_map: 86 | for tensor_name, filename in weight_map.items(): 87 | layer_num = extract_layer_num(tensor_name) 88 | if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer: 89 | shard_specific_patterns.add(filename) 90 | sorted_file_names = sorted(weight_map.values()) 91 | if shard.is_first_layer(): 92 | shard_specific_patterns.add(sorted_file_names[0]) 93 | elif shard.is_last_layer(): 94 | shard_specific_patterns.add(sorted_file_names[-1]) 95 | else: 96 | shard_specific_patterns = set(["*.safetensors"]) 97 | if DEBUG >= 3: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}") 98 | return list(default_patterns | shard_specific_patterns) 99 | -------------------------------------------------------------------------------- /exo/download/shard_download.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional, Tuple, Dict, AsyncIterator 3 | from pathlib import Path 4 | from exo.inference.shard import Shard 5 | from exo.download.download_progress import RepoProgressEvent 6 | from exo.helpers import AsyncCallbackSystem 7 | 8 | 9 | class ShardDownloader(ABC): 10 | @abstractmethod 11 | async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path: 12 | """ 13 | Ensures that the shard is downloaded. 14 | Does not allow multiple overlapping downloads at once. 15 | If you try to download a Shard which overlaps a Shard that is already being downloaded, 16 | the download will be cancelled and a new download will start. 17 | 18 | Args: 19 | shard (Shard): The shard to download. 20 | inference_engine_name (str): The inference engine used on the node hosting the shard 21 | """ 22 | pass 23 | 24 | @property 25 | @abstractmethod 26 | def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]: 27 | pass 28 | 29 | @abstractmethod 30 | async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]: 31 | """Get the download status of shards. 
32 | 33 | Returns: 34 | Optional[Dict[str, float]]: A dictionary mapping shard IDs to their download percentage (0-100), 35 | or None if status cannot be determined 36 | """ 37 | pass 38 | 39 | 40 | class NoopShardDownloader(ShardDownloader): 41 | async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path: 42 | return Path("/tmp/noop_shard") 43 | 44 | @property 45 | def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]: 46 | return AsyncCallbackSystem() 47 | 48 | async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]: 49 | if False: yield 50 | -------------------------------------------------------------------------------- /exo/download/test_new_shard_download.py: -------------------------------------------------------------------------------- 1 | from exo.download.new_shard_download import NewShardDownloader 2 | from exo.inference.shard import Shard 3 | import asyncio 4 | 5 | async def test_new_shard_download(): 6 | shard_downloader = NewShardDownloader() 7 | shard_downloader.on_progress.register("test").on_next(lambda shard, event: print(shard, event)) 8 | await shard_downloader.ensure_shard(Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16), "MLXDynamicShardInferenceEngine") 9 | async for path, shard_status in shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine"): 10 | print("Shard download status:", path, shard_status) 11 | 12 | if __name__ == "__main__": 13 | asyncio.run(test_new_shard_download()) 14 | 15 | -------------------------------------------------------------------------------- /exo/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/inference/__init__.py -------------------------------------------------------------------------------- /exo/inference/debug_inference_engine.py: -------------------------------------------------------------------------------- 1 | from exo.inference.inference_engine import InferenceEngine 2 | from exo.inference.shard import Shard 3 | from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine 4 | import asyncio 5 | import numpy as np 6 | 7 | 8 | # An inference engine should work the same for any number of Shards, as long as the Shards are continuous. 9 | async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str): 10 | from exo.inference.tinygrad.inference import Tokenizer 11 | from pathlib import Path 12 | 13 | _tokenizer = Tokenizer(str(Path(model_id)/"tokenizer.model")) 14 | 15 | prompt = "In a single word only, what is the last name of the president of the United States? 
" 16 | resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt) 17 | token_full = await inference_engine_1.sample(resp_full) 18 | 19 | next_resp_full, _ = await inference_engine_1.infer_tensor( 20 | "A", 21 | shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), 22 | input_data=token_full, 23 | ) 24 | 25 | resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt) 26 | resp2, _ = await inference_engine_2.infer_tensor( 27 | "B", 28 | shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), 29 | input_data=resp1, 30 | ) 31 | token2 = await inference_engine_2.sample(resp2) 32 | resp3, _ = await inference_engine_1.infer_tensor( 33 | "B", 34 | shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), 35 | input_data=token2, 36 | ) 37 | resp4, _ = await inference_engine_2.infer_tensor( 38 | "B", 39 | shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), 40 | input_data=resp3, 41 | ) 42 | 43 | print(f"{resp2=}") 44 | print(f"full: {_tokenizer.decode(resp_full)}") 45 | print(f"next full: {_tokenizer.decode(next_resp_full)}") 46 | print(f"resp2: {_tokenizer.decode(resp2)}") 47 | print(f"{resp4=}") 48 | print(f"resp4: {_tokenizer.decode(resp4)}") 49 | 50 | assert np.array_equal(resp_full, resp2) 51 | assert np.array_equal(next_resp_full, resp4) 52 | 53 | 54 | asyncio.run(test_inference_engine( 55 | TinygradDynamicShardInferenceEngine(), 56 | TinygradDynamicShardInferenceEngine(), 57 | "llama3-8b-sfr", 58 | )) 59 | -------------------------------------------------------------------------------- /exo/inference/dummy_inference_engine.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, TYPE_CHECKING 2 | import numpy as np 3 | from exo.inference.inference_engine import InferenceEngine 4 | from exo.inference.shard import Shard 5 | from exo.inference.tokenizers import DummyTokenizer 6 | 7 | class DummyInferenceEngine(InferenceEngine): 8 | def __init__(self): 9 | self.shard = None 10 | self.vocab_size = 1000 11 | self.hidden_size = 256 12 | self.eos_token_id = 0 13 | self.latency_mean = 0.1 14 | self.latency_stddev = 0.02 15 | self.num_generate_dummy_tokens = 10 16 | self.tokenizer = DummyTokenizer() 17 | 18 | async def encode(self, shard: Shard, prompt: str) -> np.ndarray: 19 | return np.array(self.tokenizer.encode(prompt)) 20 | 21 | async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray: 22 | if x[0] > self.num_generate_dummy_tokens: return np.array([self.tokenizer.eos_token_id]) 23 | return x 24 | 25 | async def decode(self, shard: Shard, tokens: np.ndarray) -> str: 26 | return self.tokenizer.decode(tokens) 27 | 28 | async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]: 29 | await self.ensure_shard(shard) 30 | return input_data + 1 if self.shard.is_last_layer() else input_data, None 31 | 32 | async def ensure_shard(self, shard: Shard): 33 | if self.shard == shard: return 34 | self.shard = shard 35 | 36 | async def load_checkpoint(self, shard: Shard, path: str): 37 | await self.ensure_shard(shard) 38 | -------------------------------------------------------------------------------- /exo/inference/inference_engine.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from exo.helpers import DEBUG # Make sure to import DEBUG 4 | 5 | from typing import Tuple, Optional 6 | from abc import ABC, abstractmethod 7 | from .shard import Shard 8 | from exo.download.shard_download import ShardDownloader 9 | 10 | 11 | class InferenceEngine(ABC): 12 | session = {} 13 | 14 | @abstractmethod 15 | async def encode(self, shard: Shard, prompt: str) -> np.ndarray: 16 | pass 17 | 18 | @abstractmethod 19 | async def sample(self, x: np.ndarray) -> np.ndarray: 20 | pass 21 | 22 | @abstractmethod 23 | async def decode(self, shard: Shard, tokens: np.ndarray) -> str: 24 | pass 25 | 26 | @abstractmethod 27 | async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]: 28 | pass 29 | 30 | @abstractmethod 31 | async def load_checkpoint(self, shard: Shard, path: str): 32 | pass 33 | 34 | async def save_checkpoint(self, shard: Shard, path: str): 35 | pass 36 | 37 | async def save_session(self, key, value): 38 | self.session[key] = value 39 | 40 | async def clear_session(self): 41 | self.session.clear() 42 | 43 | async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]: 44 | tokens = await self.encode(shard, prompt) 45 | if shard.model_id != 'stable-diffusion-2-1-base': 46 | x = tokens.reshape(1, -1) 47 | else: 48 | x = tokens 49 | output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state) 50 | 51 | return output_data, inference_state 52 | 53 | 54 | inference_engine_classes = { 55 | "mlx": "MLXDynamicShardInferenceEngine", 56 | "tinygrad": "TinygradDynamicShardInferenceEngine", 57 | "dummy": "DummyInferenceEngine", 58 | } 59 | 60 | 61 | def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader): 62 | if DEBUG >= 2: 63 | print(f"get_inference_engine called with: {inference_engine_name}") 64 | if inference_engine_name == "mlx": 65 | from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine 66 | 67 | return MLXDynamicShardInferenceEngine(shard_downloader) 68 | elif inference_engine_name == "tinygrad": 69 | from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine 70 | import tinygrad.helpers 71 | tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0")) 72 | 73 | return TinygradDynamicShardInferenceEngine(shard_downloader) 74 | elif inference_engine_name == "dummy": 75 | from exo.inference.dummy_inference_engine import DummyInferenceEngine 76 | return DummyInferenceEngine() 77 | raise ValueError(f"Unsupported inference engine: {inference_engine_name}") 78 | -------------------------------------------------------------------------------- /exo/inference/mlx/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/inference/mlx/__init__.py -------------------------------------------------------------------------------- /exo/inference/mlx/losses.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | import mlx.nn as nn 3 | def length_masked_ce_loss(model, inputs, targets, lengths): 4 | # Run model on inputs 5 | logits = model(inputs).astype(mx.float32) 6 | 7 | 
# Mask padding tokens 8 | length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] 9 | 10 | # Calculate the loss 11 | ce = nn.losses.cross_entropy(logits, targets) * length_mask 12 | loss = ce.sum() / length_mask.sum() 13 | # print(f"| {inputs=}\n| ==>{logits=}\n| ~^~{ce=}\n| == {loss=}") 14 | return loss 15 | 16 | #Naive intermediate layer loss, where we replace the targets with gradients and just multiply the output by the gradients to derive the loss. This is naive and may warrant some further iteration, but will do the job for now 17 | def back_gradient_loss(model, inputs, gradients, lengths): 18 | out = model(inputs).astype(mx.float32) 19 | grad = gradients.astype(mx.float32) 20 | 21 | # Mask padding tokens 22 | length_mask = mx.repeat(mx.arange(inputs.shape[1])[None, :] < lengths[:, None], out.shape[-1]).reshape(out.shape) 23 | 24 | masked_sum = (out * length_mask).sum(axis=1) 25 | gradient_lens = mx.abs(grad * masked_sum) 26 | loss = gradient_lens.sum() / length_mask.sum() 27 | # print(f"| {inputs=}\n" 28 | # + f"| ==>{out=}\n" 29 | # + f"| ~^~{masked_sum=}\n" 30 | # + f"| <~>{gradient_lens=}\n" 31 | # + f"| == {loss=}") 32 | return loss 33 | 34 | loss_fns = { 35 | "back_gradient": back_gradient_loss, 36 | "length_masked_ce": length_masked_ce_loss, 37 | } 38 | -------------------------------------------------------------------------------- /exo/inference/mlx/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/inference/mlx/models/__init__.py -------------------------------------------------------------------------------- /exo/inference/mlx/models/base.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import mlx.core as mx 3 | import mlx.nn as nn 4 | from mlx_lm.models.cache import KVCache 5 | 6 | 7 | class IdentityBlock(nn.Module): 8 | def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[KVCache] = None) -> mx.array: 9 | return x 10 | -------------------------------------------------------------------------------- /exo/inference/mlx/models/deepseek_v2.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | import mlx.core as mx 5 | import mlx.nn as nn 6 | 7 | from mlx_lm.models.cache import KVCache 8 | from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer 9 | from .base import IdentityBlock 10 | from exo.inference.shard import Shard 11 | 12 | 13 | @dataclass 14 | class ModelArgs(ModelArgs): 15 | shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) 16 | 17 | def __post_init__(self): 18 | if isinstance(self.shard, Shard): 19 | return 20 | if not isinstance(self.shard, dict): 21 | raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead") 22 | 23 | self.shard = Shard(**self.shard) 24 | 25 | 26 | class DeepseekV2Model(nn.Module): 27 | def __init__(self, config: ModelArgs): 28 | super().__init__() 29 | self.args = config 30 | self.num_hidden_layers = config.num_hidden_layers 31 | self.vocab_size = config.vocab_size 32 | if self.args.shard.is_first_layer(): 33 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) 34 | 35 | self.layers = [] 36 | for i in range(self.num_hidden_layers): 37 | if self.args.shard.start_layer <= i <= 
self.args.shard.end_layer: 38 | self.layers.append(DeepseekV2DecoderLayer(config, i)) 39 | else: 40 | self.layers.append(IdentityBlock()) 41 | 42 | if self.args.shard.is_last_layer(): 43 | self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 44 | 45 | def __call__( 46 | self, 47 | x: mx.array, 48 | cache: Optional[KVCache] = None, 49 | ) -> mx.array: 50 | if self.args.shard.is_first_layer(): 51 | h = self.embed_tokens(x) 52 | else: 53 | h = x 54 | 55 | mask = None 56 | T = h.shape[1] 57 | if T > 1: 58 | mask = nn.MultiHeadAttention.create_additive_causal_mask(T) 59 | mask = mask.astype(h.dtype) 60 | 61 | if cache is None: 62 | cache = [None]*len(self.layers) 63 | 64 | for layer, c in zip(self.layers, cache): 65 | h = layer(h, mask, c) 66 | 67 | if self.args.shard.is_last_layer(): 68 | h = self.norm(h) 69 | return h 70 | 71 | 72 | class Model(nn.Module): 73 | def __init__(self, config: ModelArgs): 74 | super().__init__() 75 | self.args = config 76 | self.model_type = config.model_type 77 | self.model = DeepseekV2Model(config) 78 | if self.args.shard.is_last_layer(): 79 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 80 | 81 | def __call__( 82 | self, 83 | inputs: mx.array, 84 | cache: Optional[KVCache] = None, 85 | ): 86 | out = self.model(inputs, cache) 87 | if self.args.shard.is_last_layer(): 88 | return self.lm_head(out) 89 | return out 90 | 91 | def sanitize(self, weights): 92 | shard_state_dict = {} 93 | 94 | for key, value in weights.items(): 95 | if key.startswith('model.layers.'): 96 | layer_num = int(key.split('.')[2]) 97 | if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer: 98 | shard_state_dict[key] = value 99 | elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'): 100 | shard_state_dict[key] = value 101 | elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')): 102 | shard_state_dict[key] = value 103 | 104 | for l in range(self.args.num_hidden_layers): 105 | prefix = f"model.layers.{l}" 106 | for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]: 107 | for k in ["weight", "scales", "biases"]: 108 | if f"{prefix}.mlp.experts.0.{m}.{k}" in shard_state_dict: 109 | to_join = [shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)] 110 | shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join) 111 | 112 | return shard_state_dict 113 | 114 | @property 115 | def layers(self): 116 | return self.model.layers 117 | 118 | @property 119 | def head_dim(self): 120 | return ( 121 | self.args.qk_nope_head_dim + self.args.qk_rope_head_dim, 122 | self.args.v_head_dim, 123 | ) 124 | 125 | @property 126 | def n_kv_heads(self): 127 | return self.args.num_key_value_heads 128 | -------------------------------------------------------------------------------- /exo/inference/mlx/models/deepseek_v3.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | import mlx.core as mx 5 | import mlx.nn as nn 6 | 7 | from mlx_lm.models.cache import KVCache 8 | from mlx_lm.models.deepseek_v3 import ( 9 | ModelArgs as V3ModelArgs, 10 | DeepseekV3DecoderLayer, 11 | ) 12 | from .base import IdentityBlock 13 | from exo.inference.shard import Shard 14 | 15 | 16 | @dataclass 17 | class ModelArgs(V3ModelArgs): 18 | shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) 19 | 20 | def 
__post_init__(self): 21 | if isinstance(self.shard, Shard): 22 | return 23 | if not isinstance(self.shard, dict): 24 | raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead") 25 | 26 | self.shard = Shard(**self.shard) 27 | 28 | 29 | class DeepseekV3Model(nn.Module): 30 | def __init__(self, config: ModelArgs): 31 | super().__init__() 32 | self.args = config 33 | self.num_hidden_layers = config.num_hidden_layers 34 | self.vocab_size = config.vocab_size 35 | if self.args.shard.is_first_layer(): 36 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) 37 | 38 | self.layers = [] 39 | for i in range(self.num_hidden_layers): 40 | if self.args.shard.start_layer <= i <= self.args.shard.end_layer: 41 | self.layers.append(DeepseekV3DecoderLayer(config, i)) 42 | else: 43 | self.layers.append(IdentityBlock()) 44 | 45 | if self.args.shard.is_last_layer(): 46 | self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 47 | 48 | def __call__( 49 | self, 50 | x: mx.array, 51 | cache: Optional[KVCache] = None, 52 | ) -> mx.array: 53 | if self.args.shard.is_first_layer(): 54 | h = self.embed_tokens(x) 55 | else: 56 | h = x 57 | 58 | mask = None 59 | T = h.shape[1] 60 | if T > 1: 61 | mask = nn.MultiHeadAttention.create_additive_causal_mask(T) 62 | mask = mask.astype(h.dtype) 63 | 64 | if cache is None: 65 | cache = [None]*len(self.layers) 66 | 67 | for layer, c in zip(self.layers, cache): 68 | h = layer(h, mask, c) 69 | 70 | if self.args.shard.is_last_layer(): 71 | h = self.norm(h) 72 | return h 73 | 74 | 75 | class Model(nn.Module): 76 | def __init__(self, config: ModelArgs): 77 | super().__init__() 78 | self.args = config 79 | self.model_type = config.model_type 80 | self.model = DeepseekV3Model(config) 81 | if self.args.shard.is_last_layer(): 82 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 83 | 84 | def __call__( 85 | self, 86 | inputs: mx.array, 87 | cache: Optional[KVCache] = None, 88 | ): 89 | out = self.model(inputs, cache) 90 | if self.args.shard.is_last_layer(): 91 | return self.lm_head(out) 92 | return out 93 | 94 | def sanitize(self, weights): 95 | shard_state_dict = {} 96 | 97 | for key, value in weights.items(): 98 | if key.startswith('model.layers.'): 99 | layer_num = int(key.split('.')[2]) 100 | if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer: 101 | shard_state_dict[key] = value 102 | elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'): 103 | shard_state_dict[key] = value 104 | elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')): 105 | shard_state_dict[key] = value 106 | 107 | for l in range(self.args.num_hidden_layers): 108 | prefix = f"model.layers.{l}" 109 | for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]: 110 | for k in ["weight", "scales", "biases"]: 111 | expert_key = f"{prefix}.mlp.experts.0.{m}.{k}" 112 | if expert_key in shard_state_dict: 113 | to_join = [ 114 | shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") 115 | for e in range(self.args.n_routed_experts) 116 | ] 117 | shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join) 118 | 119 | return shard_state_dict 120 | 121 | @property 122 | def layers(self): 123 | return self.model.layers 124 | 125 | @property 126 | def head_dim(self): 127 | return ( 128 | self.args.qk_nope_head_dim + self.args.qk_rope_head_dim, 129 | self.args.v_head_dim, 130 | ) 131 | 132 | @property 133 | def 
n_kv_heads(self): 134 | return self.args.num_key_value_heads 135 | -------------------------------------------------------------------------------- /exo/inference/mlx/models/gemma2.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import mlx.core as mx 4 | import mlx.nn as nn 5 | 6 | from mlx_lm.models.base import create_attention_mask 7 | from mlx_lm.models.gemma2 import TransformerBlock, ModelArgs, RMSNorm 8 | 9 | from ...shard import Shard 10 | from .base import IdentityBlock 11 | 12 | 13 | @dataclass 14 | class ModelArgs(ModelArgs): 15 | shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) 16 | 17 | def __post_init__(self): 18 | if isinstance(self.shard, Shard): 19 | return 20 | if not isinstance(self.shard, dict): 21 | raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead") 22 | 23 | self.shard = Shard(**self.shard) 24 | 25 | 26 | class GemmaModel(nn.Module): 27 | def __init__(self, args: ModelArgs): 28 | super().__init__() 29 | self.args = args 30 | self.vocab_size = args.vocab_size 31 | self.num_hidden_layers = args.num_hidden_layers 32 | assert self.vocab_size > 0 33 | if args.shard.is_first_layer() or args.shard.is_last_layer(): 34 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 35 | self.layers = [] 36 | for i in range(self.num_hidden_layers): 37 | if args.shard.start_layer <= i <= args.shard.end_layer: 38 | self.layers.append(TransformerBlock(args=args)) 39 | else: 40 | self.layers.append(IdentityBlock()) 41 | if args.shard.is_last_layer(): 42 | self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 43 | 44 | def __call__( 45 | self, 46 | inputs: mx.array, 47 | cache=None, 48 | ): 49 | if self.args.shard.is_first_layer(): 50 | h = self.embed_tokens(inputs) 51 | h = h * (self.args.hidden_size**0.5) 52 | else: 53 | h = inputs 54 | 55 | mask = None 56 | if h.ndim > 1 and h.shape[1] > 1: 57 | mask = create_attention_mask(h, cache) 58 | 59 | if cache is None: 60 | cache = [None]*len(self.layers) 61 | 62 | for layer, c in zip(self.layers, cache): 63 | h = layer(h, mask, cache=c) 64 | 65 | if self.args.shard.is_last_layer(): 66 | h = self.norm(h) 67 | return h 68 | 69 | 70 | class Model(nn.Module): 71 | def __init__(self, args: ModelArgs): 72 | super().__init__() 73 | self.args = args 74 | self.model_type = args.model_type 75 | self.model = GemmaModel(args) 76 | if args.shard.is_last_layer(): 77 | self.final_logit_softcapping = args.final_logit_softcapping 78 | 79 | def __call__( 80 | self, 81 | inputs: mx.array, 82 | cache=None, 83 | ): 84 | out = self.model(inputs, cache) 85 | if self.args.shard.is_last_layer(): 86 | out = self.model.embed_tokens.as_linear(out) 87 | out = mx.tanh(out / self.final_logit_softcapping) 88 | out = out * self.final_logit_softcapping 89 | return out 90 | 91 | def sanitize(self, weights): 92 | shard_state_dict = {} 93 | 94 | for key, value in weights.items(): 95 | if "self_attn.rotary_emb.inv_freq" in key: 96 | continue 97 | if key.startswith('model.layers.'): 98 | layer_num = int(key.split('.')[2]) 99 | if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer: 100 | shard_state_dict[key] = value 101 | elif (self.args.shard.is_first_layer() or self.args.shard.is_last_layer()) and key.startswith('model.embed_tokens'): 102 | shard_state_dict[key] = value 103 | elif self.args.shard.is_last_layer() and (key.startswith('model.norm')): 104 | shard_state_dict[key] = value 105 | 
106 | return shard_state_dict 107 | 108 | @property 109 | def layers(self): 110 | return self.model.layers 111 | 112 | @property 113 | def head_dim(self): 114 | return self.args.head_dim 115 | 116 | @property 117 | def n_kv_heads(self): 118 | return self.args.num_key_value_heads 119 | -------------------------------------------------------------------------------- /exo/inference/mlx/models/llama.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import mlx.core as mx 4 | import mlx.nn as nn 5 | 6 | from mlx_lm.models.base import create_attention_mask 7 | from mlx_lm.models.llama import TransformerBlock, ModelArgs 8 | 9 | from ...shard import Shard 10 | from .base import IdentityBlock 11 | 12 | 13 | @dataclass 14 | class ModelArgs(ModelArgs): 15 | shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) 16 | 17 | def __post_init__(self): 18 | super().__post_init__() # Ensure parent initializations are respected 19 | 20 | if isinstance(self.shard, Shard): 21 | return 22 | if not isinstance(self.shard, dict): 23 | raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead") 24 | 25 | self.shard = Shard(**self.shard) 26 | 27 | 28 | class LlamaModel(nn.Module): 29 | def __init__(self, args: ModelArgs): 30 | super().__init__() 31 | self.args = args 32 | self.vocab_size = args.vocab_size 33 | self.num_hidden_layers = args.num_hidden_layers 34 | assert self.vocab_size > 0 35 | if args.shard.is_first_layer() or (args.shard.is_last_layer() and args.tie_word_embeddings): 36 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 37 | self.layers = [] 38 | for i in range(self.num_hidden_layers): 39 | if args.shard.start_layer <= i <= args.shard.end_layer: 40 | self.layers.append(TransformerBlock(args=args)) 41 | else: 42 | self.layers.append(IdentityBlock()) 43 | if args.shard.is_last_layer(): 44 | self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 45 | 46 | def __call__( 47 | self, 48 | inputs: mx.array, 49 | cache=None, 50 | ): 51 | if self.args.shard.is_first_layer(): 52 | h = self.embed_tokens(inputs) 53 | else: 54 | h = inputs 55 | 56 | mask = None 57 | if h.ndim > 1 and h.shape[1] > 1: 58 | mask = create_attention_mask(h, cache) 59 | 60 | if cache is None: 61 | cache = [None]*len(self.layers) 62 | 63 | for layer, c in zip(self.layers, cache): 64 | h = layer(h, mask, cache=c) 65 | 66 | if self.args.shard.is_last_layer(): 67 | h = self.norm(h) 68 | return h 69 | 70 | 71 | class Model(nn.Module): 72 | def __init__(self, args: ModelArgs): 73 | super().__init__() 74 | self.args = args 75 | self.model_type = args.model_type 76 | self.model = LlamaModel(args) 77 | if args.shard.is_last_layer(): 78 | if not args.tie_word_embeddings: 79 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 80 | 81 | def __call__( 82 | self, 83 | inputs: mx.array, 84 | cache=None, 85 | ): 86 | out = self.model(inputs, cache) 87 | if self.args.shard.is_last_layer(): 88 | if self.args.tie_word_embeddings: 89 | out = self.model.embed_tokens.as_linear(out) 90 | else: 91 | out = self.lm_head(out) 92 | return out 93 | 94 | def sanitize(self, weights): 95 | shard_state_dict = {} 96 | 97 | for key, value in weights.items(): 98 | if "self_attn.rotary_emb.inv_freq" in key: 99 | continue 100 | if key.startswith('model.layers.'): 101 | layer_num = int(key.split('.')[2]) 102 | if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer: 103 | 
shard_state_dict[key] = value 104 | elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'): 105 | shard_state_dict[key] = value 106 | elif (self.args.shard.is_last_layer() and self.args.tie_word_embeddings) and key.startswith('model.embed_tokens'): 107 | shard_state_dict[key] = value 108 | elif (self.args.shard.is_last_layer() and not self.args.tie_word_embeddings) and key.startswith('lm_head'): 109 | shard_state_dict[key] = value 110 | elif self.args.shard.is_last_layer() and (key.startswith('model.norm')): 111 | shard_state_dict[key] = value 112 | 113 | return shard_state_dict 114 | 115 | @property 116 | def layers(self): 117 | return self.model.layers 118 | 119 | @property 120 | def head_dim(self): 121 | return (self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads) 122 | 123 | @property 124 | def n_kv_heads(self): 125 | return self.args.num_key_value_heads 126 | -------------------------------------------------------------------------------- /exo/inference/mlx/models/phi3.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import mlx.core as mx 4 | import mlx.nn as nn 5 | 6 | from mlx_lm.models.base import create_attention_mask 7 | from mlx_lm.models.phi3 import TransformerBlock, ModelArgs 8 | 9 | from ...shard import Shard 10 | from .base import IdentityBlock 11 | 12 | @dataclass 13 | class ModelArgs(ModelArgs): 14 | shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) 15 | 16 | def __post_init__(self): 17 | super().__post_init__() 18 | 19 | if isinstance(self.shard, Shard): 20 | return 21 | if not isinstance(self.shard, dict): 22 | raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead") 23 | 24 | self.shard = Shard(**self.shard) 25 | 26 | class Phi3Model(nn.Module): 27 | def __init__(self, args: ModelArgs): 28 | super().__init__() 29 | self.args = args 30 | self.vocab_size = args.vocab_size 31 | self.num_hidden_layers = args.num_hidden_layers 32 | assert self.vocab_size > 0 33 | 34 | if self.args.shard.is_first_layer(): 35 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 36 | 37 | self.layers = [] 38 | for i in range(self.num_hidden_layers): 39 | if self.args.shard.start_layer <= i <= self.args.shard.end_layer: 40 | self.layers.append(TransformerBlock(args=args)) 41 | else: 42 | self.layers.append(IdentityBlock()) 43 | 44 | if self.args.shard.is_last_layer(): 45 | self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 46 | 47 | def __call__( 48 | self, 49 | inputs: mx.array, 50 | cache=None, 51 | ): 52 | if self.args.shard.is_first_layer(): 53 | h = self.embed_tokens(inputs) 54 | else: 55 | h = inputs 56 | 57 | mask = None 58 | if h.shape[1] > 1: 59 | mask = create_attention_mask(h, cache) 60 | 61 | if cache is None: 62 | cache = [None] * len(self.layers) 63 | 64 | for layer, c in zip(self.layers, cache): 65 | h = layer(h, mask, c) 66 | 67 | if self.args.shard.is_last_layer(): 68 | h = self.norm(h) 69 | return h 70 | 71 | class Model(nn.Module): 72 | def __init__(self, args: ModelArgs): 73 | super().__init__() 74 | self.args = args 75 | self.model_type = args.model_type 76 | self.model = Phi3Model(args) 77 | if self.args.shard.is_last_layer(): 78 | self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 79 | 80 | def __call__( 81 | self, 82 | inputs: mx.array, 83 | cache=None, 84 | ): 85 | out = self.model(inputs, cache) 86 | if 
self.args.shard.is_last_layer(): 87 | out = self.lm_head(out) 88 | return out 89 | 90 | def sanitize(self, weights): 91 | shard_state_dict = {} 92 | 93 | for key, value in weights.items(): 94 | if "self_attn.rope.inv_freq" in key: 95 | continue 96 | if key.startswith('model.layers.'): 97 | layer_num = int(key.split('.')[2]) 98 | if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer: 99 | shard_state_dict[key] = value 100 | elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'): 101 | shard_state_dict[key] = value 102 | elif self.args.shard.is_last_layer() and (key.startswith('lm_head') or key.startswith('model.norm')): 103 | shard_state_dict[key] = value 104 | 105 | return shard_state_dict 106 | 107 | @property 108 | def layers(self): 109 | return self.model.layers 110 | 111 | @property 112 | def head_dim(self): 113 | return self.args.hidden_size // self.args.num_attention_heads 114 | 115 | @property 116 | def n_kv_heads(self): 117 | return self.args.num_key_value_heads 118 | -------------------------------------------------------------------------------- /exo/inference/mlx/models/qwen2.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import mlx.core as mx 4 | import mlx.nn as nn 5 | 6 | from mlx_lm.models.base import create_attention_mask 7 | from mlx_lm.models.qwen2 import TransformerBlock, ModelArgs 8 | 9 | from ...shard import Shard 10 | from .base import IdentityBlock 11 | 12 | @dataclass 13 | class ModelArgs(ModelArgs): 14 | shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0)) 15 | 16 | def __post_init__(self): 17 | super().__post_init__() 18 | 19 | if isinstance(self.shard, Shard): 20 | return 21 | if not isinstance(self.shard, dict): 22 | raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead") 23 | 24 | self.shard = Shard(**self.shard) 25 | 26 | class Qwen2Model(nn.Module): 27 | def __init__(self, args: ModelArgs): 28 | super().__init__() 29 | self.args = args 30 | self.vocab_size = args.vocab_size 31 | self.num_hidden_layers = args.num_hidden_layers 32 | assert self.vocab_size > 0 33 | 34 | if self.args.shard.is_first_layer() or (self.args.shard.is_last_layer() and args.tie_word_embeddings): 35 | self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) 36 | 37 | self.layers = [] 38 | for i in range(self.num_hidden_layers): 39 | if self.args.shard.start_layer <= i <= self.args.shard.end_layer: 40 | self.layers.append(TransformerBlock(args=args)) 41 | else: 42 | self.layers.append(IdentityBlock()) 43 | 44 | if self.args.shard.is_last_layer(): 45 | self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) 46 | 47 | def __call__( 48 | self, 49 | inputs: mx.array, 50 | cache=None, 51 | ): 52 | if self.args.shard.is_first_layer(): 53 | h = self.embed_tokens(inputs) 54 | else: 55 | h = inputs 56 | 57 | mask = None 58 | if h.shape[1] > 1: 59 | mask = create_attention_mask(h, cache) 60 | 61 | if cache is None: 62 | cache = [None]*len(self.layers) 63 | 64 | for layer, c in zip(self.layers, cache): 65 | h = layer(h, mask, c) 66 | 67 | if self.args.shard.is_last_layer(): 68 | h = self.norm(h) 69 | return h 70 | 71 | 72 | class Model(nn.Module): 73 | def __init__(self, args: ModelArgs): 74 | super().__init__() 75 | self.args = args 76 | self.model_type = args.model_type 77 | self.model = Qwen2Model(args) 78 | if self.args.shard.is_last_layer(): 79 | if not args.tie_word_embeddings: 80 | 
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 81 | 82 | def __call__( 83 | self, 84 | inputs: mx.array, 85 | cache=None, 86 | ): 87 | out = self.model(inputs, cache) 88 | if self.args.shard.is_last_layer(): 89 | if self.args.tie_word_embeddings: 90 | out = self.model.embed_tokens.as_linear(out) 91 | else: 92 | out = self.lm_head(out) 93 | return out 94 | 95 | def sanitize(self, weights): 96 | shard_state_dict = {} 97 | 98 | for key, value in weights.items(): 99 | if "self_attn.rotary_emb.inv_freq" in key: 100 | continue 101 | if key.startswith('model.layers.'): 102 | layer_num = int(key.split('.')[2]) 103 | if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer: 104 | shard_state_dict[key] = value 105 | elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'): 106 | shard_state_dict[key] = value 107 | elif (self.args.shard.is_last_layer() and self.args.tie_word_embeddings) and key.startswith('model.embed_tokens'): 108 | shard_state_dict[key] = value 109 | elif (self.args.shard.is_last_layer() and not self.args.tie_word_embeddings) and key.startswith('lm_head'): 110 | shard_state_dict[key] = value 111 | elif self.args.shard.is_last_layer() and (key.startswith('model.norm')): 112 | shard_state_dict[key] = value 113 | 114 | if self.args.tie_word_embeddings: 115 | shard_state_dict.pop("lm_head.weight", None) 116 | 117 | return shard_state_dict 118 | 119 | @property 120 | def layers(self): 121 | return self.model.layers 122 | 123 | @property 124 | def head_dim(self): 125 | return self.args.hidden_size // self.args.num_attention_heads 126 | 127 | @property 128 | def n_kv_heads(self): 129 | return self.args.num_key_value_heads 130 | -------------------------------------------------------------------------------- /exo/inference/mlx/models/sd_models/tokenizer.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/tokenizer.py 2 | 3 | import regex 4 | import json 5 | import glob 6 | 7 | 8 | class Tokenizer: 9 | """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ .""" 10 | 11 | def __init__(self, bpe_ranks, vocab): 12 | self.bpe_ranks = bpe_ranks 13 | self.vocab = vocab 14 | self.pat = regex.compile( 15 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 16 | regex.IGNORECASE, 17 | ) 18 | 19 | self._cache = {self.bos: self.bos, self.eos: self.eos} 20 | 21 | @property 22 | def bos(self): 23 | return "<|startoftext|>" 24 | 25 | @property 26 | def bos_token(self): 27 | return self.vocab[self.bos] 28 | 29 | @property 30 | def eos(self): 31 | return "<|endoftext|>" 32 | 33 | @property 34 | def eos_token(self): 35 | return self.vocab[self.eos] 36 | 37 | def bpe(self, text): 38 | if text in self._cache: 39 | return self._cache[text] 40 | 41 | unigrams = list(text[:-1]) + [text[-1] + ""] 42 | unique_bigrams = set(zip(unigrams, unigrams[1:])) 43 | 44 | if not unique_bigrams: 45 | return unigrams 46 | 47 | # In every iteration try to merge the two most likely bigrams. If none 48 | # was merged we are done. 
49 | # 50 | # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py 51 | while unique_bigrams: 52 | bigram = min( 53 | unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf")) 54 | ) 55 | if bigram not in self.bpe_ranks: 56 | break 57 | 58 | new_unigrams = [] 59 | skip = False 60 | for a, b in zip(unigrams, unigrams[1:]): 61 | if skip: 62 | skip = False 63 | continue 64 | 65 | if (a, b) == bigram: 66 | new_unigrams.append(a + b) 67 | skip = True 68 | 69 | else: 70 | new_unigrams.append(a) 71 | 72 | if not skip: 73 | new_unigrams.append(b) 74 | 75 | unigrams = new_unigrams 76 | unique_bigrams = set(zip(unigrams, unigrams[1:])) 77 | 78 | self._cache[text] = unigrams 79 | 80 | return unigrams 81 | 82 | def tokenize(self, text, prepend_bos=True, append_eos=True): 83 | if isinstance(text, list): 84 | return [self.tokenize(t, prepend_bos, append_eos) for t in text] 85 | 86 | # Lower case cleanup and split according to self.pat. Hugging Face does 87 | # a much more thorough job here but this should suffice for 95% of 88 | # cases. 89 | clean_text = regex.sub(r"\s+", " ", text.lower()) 90 | tokens = regex.findall(self.pat, clean_text) 91 | 92 | # Split the tokens according to the byte-pair merge file 93 | bpe_tokens = [ti for t in tokens for ti in self.bpe(t)] 94 | 95 | # Map to token ids and return 96 | tokens = [self.vocab[t] for t in bpe_tokens] 97 | if prepend_bos: 98 | tokens = [self.bos_token] + tokens 99 | if append_eos: 100 | tokens.append(self.eos_token) 101 | 102 | return tokens 103 | 104 | def encode(self, prompt): 105 | tokens = [self.tokenize(prompt)] 106 | negative_text = "" 107 | if negative_text is not None: 108 | tokens += [self.tokenize(negative_text)] 109 | lengths = [len(t) for t in tokens] 110 | N = max(lengths) 111 | tokens = [t + [0] * (N - len(t)) for t in tokens] 112 | return tokens 113 | 114 | def load_tokenizer( 115 | model_path: str, 116 | vocab_key: str = "tokenizer_vocab", 117 | merges_key: str = "tokenizer_merges", 118 | ): 119 | 120 | vocab_file = glob.glob(str(model_path/"tokenizer"/vocab_key))[0] 121 | with open(vocab_file, encoding="utf-8") as f: 122 | vocab = json.load(f) 123 | 124 | merges_file = glob.glob(str(model_path/"tokenizer"/merges_key))[0] 125 | with open(merges_file, encoding="utf-8") as f: 126 | bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1] 127 | bpe_merges = [tuple(m.split()) for m in bpe_merges] 128 | bpe_ranks = dict(map(reversed, enumerate(bpe_merges))) 129 | 130 | return Tokenizer(bpe_ranks, vocab) 131 | 132 | -------------------------------------------------------------------------------- /exo/inference/mlx/perf_improvements.md: -------------------------------------------------------------------------------- 1 | # Perf improvements 2 | 3 | Target: 460 tok/sec 4 | - removing sampling goes from 369 -> 402 tok/sec 5 | - performance degrades as more tokens are generated 6 | - make the mlx inference engine synchronous, removing the thread pool executor: 402 -> 413 tok/sec 7 | - remove self.on_opaque_status.trigger_all: 413 -> 418 tok/sec 8 | -------------------------------------------------------------------------------- /exo/inference/mlx/test_non_blocking.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | import numpy as np 4 | from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine 5 | from exo.download.new_shard_download import NewShardDownloader 6 | from exo.inference.shard import 
Shard 7 | from exo.models import build_base_shard 8 | from collections import deque 9 | from statistics import mean, median 10 | 11 | async def test_non_blocking(): 12 | # Setup 13 | shard_downloader = NewShardDownloader() 14 | engine = MLXDynamicShardInferenceEngine(shard_downloader) 15 | _shard = build_base_shard("llama-3.1-8b", "MLXDynamicShardInferenceEngine") 16 | shard = Shard(_shard.model_id, _shard.start_layer, _shard.n_layers - 1, _shard.n_layers) 17 | await engine.ensure_shard(shard) 18 | 19 | queue = asyncio.Queue() 20 | measurements = deque(maxlen=1000000) 21 | running = True 22 | 23 | async def mlx_worker(): 24 | try: 25 | start_time = time.time() 26 | count = 0 27 | while running and (time.time() - start_time) < 5: # Hard time limit 28 | start = time.perf_counter_ns() 29 | await engine.infer_prompt("req1", shard, "test prompt") 30 | duration = (time.perf_counter_ns() - start) / 1_000_000 # Convert to ms 31 | count += 1 32 | print(f"MLX operation {count} took: {duration:.3f}ms") 33 | except asyncio.CancelledError: 34 | pass 35 | finally: 36 | print(f"\nTotal MLX operations completed: {count}") 37 | print(f"Average rate: {count/5:.1f} ops/second") 38 | 39 | async def latency_producer(): 40 | try: 41 | start_time = time.perf_counter_ns() 42 | count = 0 43 | while running: 44 | await queue.put(time.perf_counter_ns()) 45 | count += 1 46 | await asyncio.sleep(0) # Yield to event loop without delay 47 | duration = (time.perf_counter_ns() - start_time) / 1e9 # Convert to seconds 48 | print(f"\nProducer iterations: {count}") 49 | print(f"Producer rate: {count/duration:.1f} iterations/second") 50 | except asyncio.CancelledError: 51 | pass 52 | 53 | async def latency_consumer(): 54 | try: 55 | while running: 56 | timestamp = await queue.get() 57 | latency = (time.perf_counter_ns() - timestamp) / 1_000_000 # Convert to ms 58 | measurements.append(latency) 59 | queue.task_done() 60 | except asyncio.CancelledError: 61 | pass 62 | 63 | tasks = [ 64 | asyncio.create_task(mlx_worker()), 65 | asyncio.create_task(latency_producer()), 66 | asyncio.create_task(latency_consumer()) 67 | ] 68 | 69 | try: 70 | await asyncio.wait_for(asyncio.gather(*tasks), timeout=6) 71 | except asyncio.TimeoutError: 72 | print("\nTest timed out") 73 | finally: 74 | running = False 75 | for task in tasks: 76 | task.cancel() 77 | await asyncio.gather(*tasks, return_exceptions=True) 78 | print(f"\nFinal measurement count: {len(measurements)}") 79 | 80 | if __name__ == "__main__": 81 | asyncio.run(test_non_blocking()) 82 | -------------------------------------------------------------------------------- /exo/inference/mlx/test_sharded_model.py: -------------------------------------------------------------------------------- 1 | from exo.inference.shard import Shard 2 | import mlx.core as mx 3 | import mlx.nn as nn 4 | from typing import Optional 5 | import numpy as np 6 | 7 | 8 | class DummyModel(nn.Module): 9 | def __init__(self, shard: Optional[Shard] = None): 10 | self.shard = shard 11 | self.layers = [ 12 | nn.Linear(8, 128), 13 | nn.Linear(128, 128), 14 | nn.Linear(128, 128), 15 | nn.Linear(128, 128), 16 | nn.Linear(128, 8), 17 | ] 18 | 19 | self.n_kv_heads = 4 20 | self.head_dim = 4 21 | 22 | def __call__(self, x, cache=None): 23 | if self.shard: 24 | for layer in self.layers[self.shard.start_layer:self.shard.end_layer + 1]: 25 | x = layer(x) 26 | if self.shard.is_last_layer(): 27 | x = x.reshape((1, 2, 4)) 28 | else: 29 | for layer in self.layers: 30 | x = layer(x) 31 | x = x.reshape((1, 2, 4)) 32 | 33 | return x 
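# Shards are expected to compose: feeding shard 1's output into shard 2 below must reproduce the unsharded model's output (asserted at the end of this file).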
34 | 35 | 36 | model = DummyModel() 37 | model.save_weights("./test_weights.npz") 38 | n_layers = 5 39 | shard1 = Shard("test", 0, n_layers // 2, n_layers) 40 | sharded_model1 = DummyModel(shard1) 41 | shard2 = Shard("test", n_layers//2 + 1, n_layers - 1, n_layers) 42 | sharded_model2 = DummyModel(shard2) 43 | 44 | model.load_weights("./test_weights.npz") 45 | sharded_model1.load_weights("./test_weights.npz") 46 | sharded_model2.load_weights("./test_weights.npz") 47 | 48 | fullresp = model(mx.array([1, 2, 3, 4, 5, 6, 7, 8])) 49 | resp1 = sharded_model1(mx.array([1, 2, 3, 4, 5, 6, 7, 8])) 50 | resp2 = sharded_model2(resp1) 51 | 52 | assert np.all(np.array(fullresp) == np.array(resp2)) 53 | -------------------------------------------------------------------------------- /exo/inference/shard.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | 4 | @dataclass(frozen=True) 5 | class Shard: 6 | model_id: str 7 | start_layer: int 8 | end_layer: int 9 | n_layers: int 10 | 11 | def __hash__(self): 12 | return hash((self.model_id, self.start_layer, self.end_layer, self.n_layers)) 13 | 14 | def is_first_layer(self) -> bool: 15 | return self.start_layer == 0 16 | 17 | def is_last_layer(self) -> bool: 18 | return self.end_layer == self.n_layers - 1 19 | 20 | def get_layer_count(self) -> int: 21 | return self.end_layer - self.start_layer + 1 22 | 23 | def to_dict(self) -> dict: 24 | return { 25 | "model_id": self.model_id, 26 | "start_layer": self.start_layer, 27 | "end_layer": self.end_layer, 28 | "n_layers": self.n_layers, 29 | } 30 | 31 | def from_dict(data: dict) -> 'Shard': 32 | return Shard(**data) 33 | 34 | def overlaps(self, other: 'Shard') -> bool: 35 | return shards_overlap(self, other) 36 | 37 | 38 | def shards_overlap(shard1: Shard, shard2: Shard) -> bool: 39 | return (shard1.model_id == shard2.model_id and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer)) 40 | -------------------------------------------------------------------------------- /exo/inference/test_dummy_inference_engine.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | from exo.inference.dummy_inference_engine import DummyInferenceEngine 4 | from exo.inference.shard import Shard 5 | 6 | 7 | @pytest.mark.asyncio 8 | async def test_dummy_inference_specific(): 9 | engine = DummyInferenceEngine() 10 | test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1) 11 | test_prompt = "This is a test prompt" 12 | 13 | result, _ = await engine.infer_prompt("test_request", test_shard, test_prompt) 14 | 15 | print(f"Inference result shape: {result.shape}") 16 | 17 | assert result.shape[0] == 1, "Result should be a 2D array with first dimension 1" 18 | 19 | 20 | @pytest.mark.asyncio 21 | async def test_dummy_inference_engine(): 22 | # Initialize the DummyInferenceEngine 23 | engine = DummyInferenceEngine() 24 | 25 | # Create a test shard 26 | shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1) 27 | 28 | # Test infer_prompt 29 | output, _ = await engine.infer_prompt("test_id", shard, "Test prompt") 30 | 31 | assert isinstance(output, np.ndarray), "Output should be a numpy array" 32 | assert output.ndim == 2, "Output should be 2-dimensional" 33 | 34 | # Test infer_tensor 35 | input_tensor = np.array([[1, 2, 3]]) 36 | output, _ = await engine.infer_tensor("test_id", shard, input_tensor) 37 | 38 | 
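# DummyInferenceEngine.infer_tensor echoes its input (adding 1 when the shard contains the last layer), so the output keeps the input's 2D shape.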
assert isinstance(output, np.ndarray), "Output should be a numpy array" 39 | assert output.ndim == 2, "Output should be 2-dimensional" 40 | 41 | print("All tests passed!") 42 | 43 | 44 | if __name__ == "__main__": 45 | import asyncio 46 | asyncio.run(test_dummy_inference_engine()) 47 | asyncio.run(test_dummy_inference_specific()) 48 | -------------------------------------------------------------------------------- /exo/inference/test_inference_engine.py: -------------------------------------------------------------------------------- 1 | from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine 2 | from exo.inference.inference_engine import InferenceEngine 3 | from exo.download.new_shard_download import NewShardDownloader 4 | from exo.inference.shard import Shard 5 | from exo.helpers import DEBUG 6 | import os 7 | import asyncio 8 | import numpy as np 9 | 10 | 11 | # An inference engine should work the same for any number of Shards, as long as the Shards are continuous. 12 | async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int): 13 | prompt = "In a single word only, what is the last name of the current president of the USA?" 14 | resp_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt) 15 | token_full = await inference_engine_1.sample(resp_full) 16 | token_full = token_full.reshape(1, -1) 17 | next_resp_full, _ = await inference_engine_1.infer_tensor( 18 | "A", 19 | shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), 20 | input_data=token_full, 21 | ) 22 | 23 | pp = n_layers // 2 24 | resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt) 25 | resp2, _ = await inference_engine_2.infer_tensor( 26 | "B", 27 | shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers), 28 | input_data=resp1, 29 | ) 30 | tokens2 = await inference_engine_1.sample(resp2) 31 | tokens2 = tokens2.reshape(1, -1) 32 | resp3, _ = await inference_engine_1.infer_tensor( 33 | "B", 34 | shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), 35 | input_data=tokens2, 36 | ) 37 | resp4, _ = await inference_engine_2.infer_tensor( 38 | "B", 39 | shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers), 40 | input_data=resp3, 41 | ) 42 | 43 | assert np.array_equal(resp_full, resp2) 44 | assert np.array_equal(next_resp_full, resp4) 45 | 46 | 47 | asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(NewShardDownloader()), MLXDynamicShardInferenceEngine(NewShardDownloader()), "llama-3.2-1b", 16)) 48 | 49 | if os.getenv("RUN_TINYGRAD", default="0") == "1": 50 | import tinygrad 51 | import os 52 | from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine 53 | tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0")) 54 | asyncio.run(test_inference_engine(TinygradDynamicShardInferenceEngine(NewShardDownloader()), TinygradDynamicShardInferenceEngine(NewShardDownloader()), "llama-3.2-1b", 32)) 55 | -------------------------------------------------------------------------------- /exo/inference/tinygrad/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/inference/tinygrad/__init__.py -------------------------------------------------------------------------------- /exo/inference/tinygrad/losses.py: -------------------------------------------------------------------------------- 1 | from tinygrad import Tensor, dtypes 2 | import numpy as np 3 | def length_masked_ce_loss(model, inputs, targets, lengths): 4 | # Run model on inputs 5 | logits = model(inputs).cast(dtypes.float32).contiguous() 6 | 7 | # Mask padding tokens 8 | length_mask = Tensor(np.arange(inputs.shape[1])[None, :] < lengths[:, None], requires_grad=False) 9 | 10 | # Calculate the loss 11 | ce = logits.sparse_categorical_crossentropy(Tensor(targets, requires_grad=False)).mul(length_mask) 12 | loss = ce.sum() / length_mask.sum() 13 | return loss 14 | 15 | -------------------------------------------------------------------------------- /exo/inference/tinygrad/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/inference/tinygrad/models/__init__.py -------------------------------------------------------------------------------- /exo/inference/tinygrad/stateful_model.py: -------------------------------------------------------------------------------- 1 | from tinygrad import Tensor, Variable 2 | from tinygrad.helpers import getenv 3 | from typing import List, Optional 4 | 5 | def create_kv_cache(x: Tensor, layer): 6 | cache_kv = Tensor.zeros(2, x.shape[0], layer.max_context, layer.n_kv_heads, layer.head_dim, dtype=x.dtype).contiguous().realize() 7 | if isinstance(x.device, tuple): 8 | # TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded 9 | cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize() 10 | return cache_kv.realize() 11 | 12 | class ModelState: 13 | cache: List[Tensor] 14 | start: int 15 | def __init__(self, cache: List[Tensor], start: int = 0): 16 | self.cache = cache 17 | self.start = start 18 | 19 | def make_prompt_state(x: Tensor, model): 20 | cache = [create_kv_cache(x, l.attention) for l in model.layers] 21 | 22 | return ModelState(cache) 23 | -------------------------------------------------------------------------------- /exo/inference/tinygrad/tinygrad_helpers.py: -------------------------------------------------------------------------------- 1 | from tinygrad.nn.state import safe_load, torch_load 2 | from tinygrad import Tensor 3 | from pathlib import Path 4 | import json 5 | from typing import List 6 | from exo.inference.shard import Shard 7 | from exo.helpers import DEBUG 8 | from exo.download.hf.hf_helpers import get_allow_patterns 9 | from fnmatch import fnmatch 10 | import re 11 | 12 | 13 | # **** helper functions **** 14 | def concat_weights(models, device=None): 15 | def convert(name) -> Tensor: 16 | disk_tensors: List[Tensor] = [model[name] for model in models] 17 | if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1: 18 | return disk_tensors[0].to(device=device) 19 | axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0 20 | lazy_tensors = [data.to(device=device) for data in disk_tensors] 21 | return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis) 22 | 23 | return {name: convert(name) for name in {name: None for model in models for name in model}} 24 | 25 | 26 | def load(fn: str, shard: Shard): 27 | 
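# Load only the weights this shard needs: follow the weight_map in a *.index.json, filter a single .safetensors file by layer range, or fall back to torch_load for raw checkpoints.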
if fn.endswith('.index.json'): 28 | with open(fn) as fp: 29 | weight_map = json.load(fp)['weight_map'] 30 | parts = {} 31 | filtered_weight_map = {} 32 | allow_patterns = get_allow_patterns(weight_map, shard) 33 | for k, n in weight_map.items(): 34 | if allow_patterns is not None and not any(fnmatch(n, r) for r in allow_patterns): 35 | continue 36 | if k.startswith("model.layers."): 37 | layer_num = int(k.split('.')[2]) 38 | if layer_num < shard.start_layer or layer_num > shard.end_layer: 39 | continue 40 | 41 | parts[n] = load(str(Path(fn).parent/Path(n).name), shard) 42 | filtered_weight_map[k] = n 43 | if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}") 44 | return {k: parts[n][k] for k, n in filtered_weight_map.items()} 45 | elif fn.endswith(".safetensors"): 46 | weight_map = safe_load(fn) 47 | for k in list(weight_map): 48 | if (n := re.search(r"\.(\d+)\.", k)) and not (shard.start_layer <= int(n.group(1)) <= shard.end_layer): 49 | del weight_map[k] 50 | return weight_map 51 | else: 52 | return torch_load(fn) 53 | -------------------------------------------------------------------------------- /exo/inference/tokenizers.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | from os import PathLike 3 | from aiofiles import os as aios 4 | from typing import Union 5 | from transformers import AutoTokenizer, AutoProcessor 6 | import numpy as np 7 | from exo.helpers import DEBUG 8 | from exo.download.new_shard_download import ensure_downloads_dir 9 | 10 | 11 | class DummyTokenizer: 12 | def __init__(self): 13 | self.eos_token_id = 69 14 | self.vocab_size = 1000 15 | 16 | def apply_chat_template(self, conversation, tokenize=True, add_generation_prompt=True, tools=None, **kwargs): 17 | return "dummy_tokenized_prompt" 18 | 19 | def encode(self, text): 20 | return np.array([1]) 21 | 22 | def decode(self, tokens): 23 | return "dummy" * len(tokens) 24 | 25 | 26 | async def resolve_tokenizer(repo_id: Union[str, PathLike]): 27 | if repo_id == "dummy": 28 | return DummyTokenizer() 29 | local_path = await ensure_downloads_dir()/str(repo_id).replace("/", "--") 30 | if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}") 31 | try: 32 | if local_path and await aios.path.exists(local_path): 33 | if DEBUG >= 2: print(f"Resolving tokenizer for {repo_id=} from {local_path=}") 34 | return await _resolve_tokenizer(local_path) 35 | except: 36 | if DEBUG >= 5: print(f"Local check for {local_path=} failed. 
Resolving tokenizer for {repo_id=} normally...") 37 | if DEBUG >= 5: traceback.print_exc() 38 | return await _resolve_tokenizer(repo_id) 39 | 40 | 41 | async def _resolve_tokenizer(repo_id_or_local_path: Union[str, PathLike]): 42 | try: 43 | if DEBUG >= 4: print(f"Trying AutoProcessor for {repo_id_or_local_path}") 44 | processor = AutoProcessor.from_pretrained(repo_id_or_local_path, use_fast=True if "Mistral-Large" in f"{repo_id_or_local_path}" else False, trust_remote_code=True) 45 | if not hasattr(processor, 'eos_token_id'): 46 | processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id 47 | if not hasattr(processor, 'encode'): 48 | processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode 49 | if not hasattr(processor, 'decode'): 50 | processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode 51 | return processor 52 | except Exception as e: 53 | if DEBUG >= 4: print(f"Failed to load processor for {repo_id_or_local_path}. Error: {e}") 54 | if DEBUG >= 4: print(traceback.format_exc()) 55 | 56 | try: 57 | if DEBUG >= 4: print(f"Trying AutoTokenizer for {repo_id_or_local_path}") 58 | return AutoTokenizer.from_pretrained(repo_id_or_local_path, trust_remote_code=True) 59 | except Exception as e: 60 | if DEBUG >= 4: print(f"Failed to load tokenizer for {repo_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}") 61 | if DEBUG >= 4: print(traceback.format_exc()) 62 | 63 | raise ValueError(f"[TODO] Unsupported model: {repo_id_or_local_path}") 64 | -------------------------------------------------------------------------------- /exo/networking/__init__.py: -------------------------------------------------------------------------------- 1 | from .discovery import Discovery 2 | from .peer_handle import PeerHandle 3 | from .server import Server 4 | 5 | __all__ = ["Discovery", "PeerHandle", "Server"] 6 | -------------------------------------------------------------------------------- /exo/networking/discovery.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List 3 | from .peer_handle import PeerHandle 4 | 5 | 6 | class Discovery(ABC): 7 | @abstractmethod 8 | async def start(self) -> None: 9 | pass 10 | 11 | @abstractmethod 12 | async def stop(self) -> None: 13 | pass 14 | 15 | @abstractmethod 16 | async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: 17 | pass 18 | -------------------------------------------------------------------------------- /exo/networking/grpc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/networking/grpc/__init__.py -------------------------------------------------------------------------------- /exo/networking/grpc/node_service.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package node_service; 4 | 5 | service NodeService { 6 | rpc SendPrompt (PromptRequest) returns (Tensor) {} 7 | rpc SendTensor (TensorRequest) returns (Tensor) {} 8 | rpc SendExample (ExampleRequest) returns (Loss) {} 9 | rpc CollectTopology (CollectTopologyRequest) returns (Topology) {} 10 | rpc SendResult (SendResultRequest) returns (Empty) {} 11 | rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {} 12 | 
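// Lightweight liveness probe; discovery uses this to verify a peer is reachable and healthy before adding it to the known peers.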
rpc HealthCheck (HealthCheckRequest) returns (HealthCheckResponse) {} 13 | } 14 | 15 | message Shard { 16 | string model_id = 1; 17 | int32 start_layer = 2; 18 | int32 end_layer = 3; 19 | int32 n_layers = 4; 20 | } 21 | 22 | message PromptRequest { 23 | Shard shard = 1; 24 | string prompt = 2; 25 | optional string request_id = 3; 26 | optional InferenceState inference_state = 4; 27 | } 28 | 29 | message TensorRequest { 30 | Shard shard = 1; 31 | Tensor tensor = 2; 32 | optional string request_id = 3; 33 | optional InferenceState inference_state = 4; 34 | } 35 | 36 | message ExampleRequest { 37 | Shard shard = 1; 38 | Tensor example = 2; 39 | Tensor target = 3; 40 | Tensor length = 4; 41 | bool train = 5; 42 | optional string request_id = 6; 43 | } 44 | 45 | message Loss { 46 | float loss = 1; 47 | optional Tensor grads = 2; 48 | } 49 | 50 | message Tensor { 51 | bytes tensor_data = 1; 52 | repeated int32 shape = 2; 53 | string dtype = 3; 54 | } 55 | 56 | message TensorList { 57 | repeated Tensor tensors = 1; 58 | } 59 | 60 | message InferenceState { 61 | map tensor_data = 1; 62 | map tensor_list_data = 2; 63 | string other_data_json = 3; 64 | } 65 | 66 | message CollectTopologyRequest { 67 | repeated string visited = 1; 68 | int32 max_depth = 2; 69 | } 70 | 71 | message Topology { 72 | map nodes = 1; 73 | map peer_graph = 2; 74 | } 75 | 76 | message PeerConnection { 77 | string to_id = 1; 78 | optional string description = 2; 79 | } 80 | 81 | message PeerConnections { 82 | repeated PeerConnection connections = 1; 83 | } 84 | 85 | message DeviceFlops { 86 | double fp32 = 1; 87 | double fp16 = 2; 88 | double int8 = 3; 89 | } 90 | 91 | message DeviceCapabilities { 92 | string model = 1; 93 | string chip = 2; 94 | int32 memory = 3; 95 | DeviceFlops flops = 4; 96 | } 97 | 98 | message SendResultRequest { 99 | string request_id = 1; 100 | repeated int32 result = 2; 101 | optional Tensor tensor = 3; 102 | bool is_finished = 4; 103 | } 104 | 105 | message SendOpaqueStatusRequest { 106 | string request_id = 1; 107 | string status = 2; 108 | } 109 | 110 | message HealthCheckRequest {} 111 | 112 | message HealthCheckResponse { 113 | bool is_healthy = 1; 114 | } 115 | 116 | message Empty {} 117 | -------------------------------------------------------------------------------- /exo/networking/manual/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/networking/manual/__init__.py -------------------------------------------------------------------------------- /exo/networking/manual/manual_discovery.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | from typing import Dict, List, Callable, Optional 4 | from concurrent.futures import ThreadPoolExecutor 5 | 6 | from exo.networking.discovery import Discovery 7 | from exo.topology.device_capabilities import DeviceCapabilities 8 | from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig 9 | from exo.helpers import DEBUG_DISCOVERY 10 | from exo.networking.peer_handle import PeerHandle 11 | 12 | 13 | class ManualDiscovery(Discovery): 14 | def __init__( 15 | self, 16 | network_config_path: str, 17 | node_id: str, 18 | create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle], 19 | ): 20 | self.network_config_path = network_config_path 21 | self.node_id = node_id 22 | self.create_peer_handle = 
create_peer_handle 23 | 24 | self.listen_task = None 25 | self.known_peers: Dict[str, PeerHandle] = {} 26 | 27 | self._cached_peers: Dict[str, PeerConfig] = {} 28 | self._last_modified_time: Optional[float] = None 29 | self._file_executor = ThreadPoolExecutor(max_workers=1) 30 | 31 | async def start(self) -> None: 32 | self.listen_task = asyncio.create_task(self.task_find_peers_from_config()) 33 | 34 | async def stop(self) -> None: 35 | if self.listen_task: self.listen_task.cancel() 36 | self._file_executor.shutdown(wait=True) 37 | 38 | async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]: 39 | if wait_for_peers > 0: 40 | while len(self.known_peers) < wait_for_peers: 41 | if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...") 42 | await asyncio.sleep(0.1) 43 | if DEBUG_DISCOVERY >= 2: print(f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}") 44 | return list(self.known_peers.values()) 45 | 46 | async def task_find_peers_from_config(self): 47 | if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...") 48 | while True: 49 | peers_from_config = await self._get_peers() 50 | new_known_peers = {} 51 | for peer_id, peer_config in peers_from_config.items(): 52 | try: 53 | if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}") 54 | peer = self.known_peers.get(peer_id) 55 | if not peer: 56 | if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.") 57 | peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", "MAN", peer_config.device_capabilities) 58 | is_healthy = await peer.health_check() 59 | if is_healthy: 60 | if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.") 61 | new_known_peers[peer_id] = peer 62 | elif DEBUG_DISCOVERY >= 2: 63 | print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.") 64 | except Exception as e: 65 | if DEBUG_DISCOVERY >= 2: print(f"Exception occurred when attempting to add {peer_id=}: {e}") 66 | self.known_peers = new_known_peers 67 | await asyncio.sleep(5.0) 68 | 69 | if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}") 70 | 71 | async def _get_peers(self): 72 | try: 73 | loop = asyncio.get_running_loop() 74 | current_mtime = await loop.run_in_executor(self._file_executor, os.path.getmtime, self.network_config_path) 75 | 76 | if (self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time): 77 | return self._cached_peers 78 | 79 | topology = await loop.run_in_executor(self._file_executor, NetworkTopology.from_path, self.network_config_path) 80 | 81 | if self.node_id not in topology.peers: 82 | raise ValueError( 83 | f"Node ID {self.node_id} not found in network config file " 84 | f"{self.network_config_path}. Please run with `node_id` set to " 85 | f"one of the keys in the config file: {[k for k, _ in topology.peers]}" 86 | ) 87 | 88 | peers_in_network = topology.peers 89 | peers_in_network.pop(self.node_id) 90 | 91 | self._cached_peers = peers_in_network 92 | self._last_modified_time = current_mtime 93 | 94 | return peers_in_network 95 | 96 | except Exception as e: 97 | if DEBUG_DISCOVERY >= 2: 98 | print(f"Error when loading network config file from {self.network_config_path}. 
" 99 | f"Please update the config file in order to successfully discover peers. " 100 | f"Exception: {e}") 101 | return self._cached_peers 102 | -------------------------------------------------------------------------------- /exo/networking/manual/network_topology_config.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from pydantic import BaseModel, ValidationError 3 | 4 | from exo.topology.device_capabilities import DeviceCapabilities 5 | 6 | 7 | class PeerConfig(BaseModel): 8 | address: str 9 | port: int 10 | device_capabilities: DeviceCapabilities 11 | 12 | 13 | class NetworkTopology(BaseModel): 14 | """Configuration of the network. A collection outlining all nodes in the network, including the node this is running from.""" 15 | 16 | peers: Dict[str, PeerConfig] 17 | """ 18 | node_id to PeerConfig. The node_id is used to identify the peer in the discovery process. The node that this is running from should be included in this dict. 19 | """ 20 | @classmethod 21 | def from_path(cls, path: str) -> "NetworkTopology": 22 | try: 23 | with open(path, "r") as f: 24 | config_data = f.read() 25 | except FileNotFoundError as e: 26 | raise FileNotFoundError(f"Config file not found at {path}") from e 27 | 28 | try: 29 | return cls.model_validate_json(config_data) 30 | except ValidationError as e: 31 | raise ValueError(f"Error validating network topology config from {path}: {e}") from e 32 | -------------------------------------------------------------------------------- /exo/networking/manual/test_data/invalid_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "peers": { 3 | "node1": { 4 | "address": "localhost", 5 | "device_capabilities": { 6 | "model": "Unknown Model", 7 | "chip": "Unknown Chip", 8 | "memory": 0, 9 | "flops": { 10 | "fp32": 0, 11 | "fp16": 0, 12 | "int8": 0 13 | } 14 | } 15 | } 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /exo/networking/manual/test_data/invalid_json.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/networking/manual/test_data/invalid_json.json -------------------------------------------------------------------------------- /exo/networking/manual/test_data/test_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "peers": { 3 | "node1": { 4 | "address": "localhost", 5 | "port": 50051, 6 | "device_capabilities": { 7 | "model": "Unknown Model", 8 | "chip": "Unknown Chip", 9 | "memory": 0, 10 | "flops": { 11 | "fp32": 0, 12 | "fp16": 0, 13 | "int8": 0 14 | } 15 | } 16 | }, 17 | "node2": { 18 | "address": "localhost", 19 | "port": 50052, 20 | "device_capabilities": { 21 | "model": "Unknown Model", 22 | "chip": "Unknown Chip", 23 | "memory": 0, 24 | "flops": { 25 | "fp32": 0, 26 | "fp16": 0, 27 | "int8": 0 28 | } 29 | } 30 | } 31 | } 32 | } -------------------------------------------------------------------------------- /exo/networking/manual/test_data/test_config_single_node.json: -------------------------------------------------------------------------------- 1 | { 2 | "peers": { 3 | "node1": { 4 | "address": "localhost", 5 | "port": 50051, 6 | "device_capabilities": { 7 | "model": "Unknown Model", 8 | "chip": "Unknown Chip", 9 | "memory": 0, 10 | "flops": { 11 | "fp32": 0, 12 | "fp16": 0, 13 | "int8": 0 14 | } 
15 | } 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /exo/networking/manual/test_manual_discovery.py: -------------------------------------------------------------------------------- 1 | import json 2 | import asyncio 3 | import unittest 4 | from unittest import mock 5 | from exo.networking.manual.manual_discovery import ManualDiscovery 6 | from exo.networking.manual.network_topology_config import NetworkTopology 7 | from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle 8 | from exo.networking.grpc.grpc_server import GRPCServer 9 | from exo.orchestration.node import Node 10 | 11 | root_path = "./exo/networking/manual/test_data/test_config.json" 12 | 13 | 14 | class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase): 15 | async def asyncSetUp(self): 16 | self.peer1 = mock.AsyncMock() 17 | self.peer1.connect = mock.AsyncMock() 18 | self.discovery1 = ManualDiscovery( 19 | root_path, 20 | "node1", 21 | create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1, 22 | ) 23 | await self.discovery1.start() 24 | 25 | async def asyncTearDown(self): 26 | await self.discovery1.stop() 27 | 28 | async def test_discovery(self): 29 | peers1 = await self.discovery1.discover_peers(wait_for_peers=0) 30 | assert len(peers1) == 0 31 | 32 | self.peer1.connect.assert_not_called() 33 | 34 | 35 | class TestManualDiscovery(unittest.IsolatedAsyncioTestCase): 36 | async def asyncSetUp(self): 37 | self.peer1 = mock.AsyncMock() 38 | self.peer2 = mock.AsyncMock() 39 | self.peer1.connect = mock.AsyncMock() 40 | self.peer2.connect = mock.AsyncMock() 41 | self.discovery1 = ManualDiscovery( 42 | root_path, 43 | "node1", 44 | create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1, 45 | ) 46 | self.discovery2 = ManualDiscovery( 47 | root_path, 48 | "node2", 49 | create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2, 50 | ) 51 | await self.discovery1.start() 52 | await self.discovery2.start() 53 | 54 | async def asyncTearDown(self): 55 | await self.discovery1.stop() 56 | await self.discovery2.stop() 57 | 58 | async def test_discovery(self): 59 | peers1 = await self.discovery1.discover_peers(wait_for_peers=1) 60 | assert len(peers1) == 1 61 | peers2 = await self.discovery2.discover_peers(wait_for_peers=1) 62 | assert len(peers2) == 1 63 | 64 | # connect has to be explicitly called after discovery 65 | self.peer1.connect.assert_not_called() 66 | self.peer2.connect.assert_not_called() 67 | 68 | 69 | class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase): 70 | async def asyncSetUp(self): 71 | config = NetworkTopology.from_path(root_path) 72 | 73 | self.node1 = mock.AsyncMock(spec=Node) 74 | self.node2 = mock.AsyncMock(spec=Node) 75 | self.server1 = GRPCServer(self.node1, config.peers["node1"].address, config.peers["node1"].port) 76 | self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port) 77 | await self.server1.start() 78 | await self.server2.start() 79 | self.discovery1 = ManualDiscovery( 80 | root_path, 81 | "node1", 82 | create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities), 83 | ) 84 | self.discovery2 = ManualDiscovery( 85 | root_path, 86 | "node2", 87 | create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, 
device_capabilities), 88 | ) 89 | await self.discovery1.start() 90 | await self.discovery2.start() 91 | 92 | async def asyncTearDown(self): 93 | await self.discovery1.stop() 94 | await self.discovery2.stop() 95 | await self.server1.stop() 96 | await self.server2.stop() 97 | 98 | async def test_grpc_discovery(self): 99 | peers1 = await self.discovery1.discover_peers(wait_for_peers=1) 100 | assert len(peers1) == 1 101 | peers2 = await self.discovery2.discover_peers(wait_for_peers=1) 102 | assert len(peers2) == 1 103 | 104 | # Connect 105 | await peers1[0].connect() 106 | await peers2[0].connect() 107 | self.assertTrue(await peers1[0].is_connected()) 108 | self.assertTrue(await peers2[0].is_connected()) 109 | 110 | # Kill server1 111 | await self.server1.stop() 112 | 113 | self.assertTrue(await peers1[0].is_connected()) 114 | self.assertFalse(await peers2[0].is_connected()) 115 | 116 | # Kill server2 117 | await self.server2.stop() 118 | 119 | self.assertFalse(await peers1[0].is_connected()) 120 | self.assertFalse(await peers2[0].is_connected()) 121 | 122 | async def test_dynamic_config_update(self): 123 | initial_peers = await self.discovery1.discover_peers(wait_for_peers=1) 124 | self.assertEqual(len(initial_peers), 1) 125 | 126 | # Save original config for cleanup 127 | with open(root_path, "r") as f: 128 | original_config = json.load(f) 129 | 130 | try: 131 | updated_config = { 132 | "peers": { 133 | **original_config["peers"], 134 | "node3": { 135 | "address": "localhost", 136 | "port": 50053, 137 | "device_capabilities": { 138 | "model": "Unknown Model", 139 | "chip": "Unknown Chip", 140 | "memory": 0, 141 | "flops": {"fp32": 0, "fp16": 0, "int8": 0}, 142 | }, 143 | }, 144 | } 145 | } 146 | 147 | with open(root_path, "w") as f: 148 | json.dump(updated_config, f, indent=2) 149 | 150 | node3 = mock.AsyncMock(spec=Node) 151 | server3 = GRPCServer(node3, "localhost", 50053) 152 | await server3.start() 153 | 154 | try: 155 | # Wait for the config to be reloaded 156 | await asyncio.sleep(1.5) 157 | 158 | updated_peers = await self.discovery1.discover_peers(wait_for_peers=2) 159 | self.assertEqual(len(updated_peers), 2) 160 | 161 | for peer in updated_peers: 162 | await peer.connect() 163 | self.assertTrue(await peer.is_connected()) 164 | 165 | finally: 166 | await server3.stop() 167 | 168 | finally: 169 | # Restore the original config file 170 | with open(root_path, "w") as f: 171 | json.dump(original_config, f, indent=2) 172 | 173 | # Wait for the config to be reloaded again 174 | await asyncio.sleep(1.5) 175 | 176 | updated_peers = await self.discovery1.discover_peers(wait_for_peers=1) 177 | self.assertEqual(len(updated_peers), 1) 178 | 179 | 180 | if __name__ == "__main__": 181 | asyncio.run(unittest.main()) 182 | -------------------------------------------------------------------------------- /exo/networking/manual/test_network_topology_config.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from exo.networking.manual.network_topology_config import NetworkTopology 4 | 5 | root_path = "./exo/networking/manual/test_data/" 6 | 7 | 8 | class TestNetworkTopologyConfig(unittest.TestCase): 9 | def test_from_path_invalid_path(self): 10 | with self.assertRaises(FileNotFoundError) as e: 11 | NetworkTopology.from_path("invalid_path") 12 | self.assertEqual(str(e.exception), "Config file not found at invalid_path") 13 | 14 | def test_from_path_invalid_json(self): 15 | with self.assertRaises(ValueError) as e: 16 | 
NetworkTopology.from_path(root_path + "invalid_json.json") 17 | self.assertIn("Error validating network topology config from", str(e.exception)) 18 | self.assertIn("1 validation error for NetworkTopology\n Invalid JSON: EOF while parsing a value at line 1 column 0", str(e.exception)) 19 | 20 | def test_from_path_invalid_config(self): 21 | with self.assertRaises(ValueError) as e: 22 | NetworkTopology.from_path(root_path + "invalid_config.json") 23 | self.assertIn("Error validating network topology config from", str(e.exception)) 24 | self.assertIn("port\n Field required", str(e.exception)) 25 | 26 | def test_from_path_valid(self): 27 | config = NetworkTopology.from_path(root_path + "test_config.json") 28 | 29 | self.assertEqual(config.peers["node1"].port, 50051) 30 | self.assertEqual(config.peers["node1"].device_capabilities.model, "Unknown Model") 31 | self.assertEqual(config.peers["node1"].address, "localhost") 32 | self.assertEqual(config.peers["node1"].device_capabilities.chip, "Unknown Chip") 33 | self.assertEqual(config.peers["node1"].device_capabilities.memory, 0) 34 | self.assertEqual(config.peers["node1"].device_capabilities.flops.fp32, 0) 35 | self.assertEqual(config.peers["node1"].device_capabilities.flops.fp16, 0) 36 | self.assertEqual(config.peers["node1"].device_capabilities.flops.int8, 0) 37 | 38 | self.assertEqual(config.peers["node2"].port, 50052) 39 | self.assertEqual(config.peers["node2"].device_capabilities.model, "Unknown Model") 40 | self.assertEqual(config.peers["node2"].address, "localhost") 41 | self.assertEqual(config.peers["node2"].device_capabilities.chip, "Unknown Chip") 42 | self.assertEqual(config.peers["node2"].device_capabilities.memory, 0) 43 | self.assertEqual(config.peers["node2"].device_capabilities.flops.fp32, 0) 44 | self.assertEqual(config.peers["node2"].device_capabilities.flops.fp16, 0) 45 | self.assertEqual(config.peers["node2"].device_capabilities.flops.int8, 0) 46 | 47 | 48 | if __name__ == "__main__": 49 | unittest.main() 50 | -------------------------------------------------------------------------------- /exo/networking/peer_handle.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional, Tuple, List 3 | import numpy as np 4 | from exo.inference.shard import Shard 5 | from exo.topology.device_capabilities import DeviceCapabilities 6 | from exo.topology.topology import Topology 7 | 8 | 9 | class PeerHandle(ABC): 10 | @abstractmethod 11 | def id(self) -> str: 12 | pass 13 | 14 | @abstractmethod 15 | def addr(self) -> str: 16 | pass 17 | 18 | @abstractmethod 19 | def description(self) -> str: 20 | pass 21 | 22 | @abstractmethod 23 | def device_capabilities(self) -> DeviceCapabilities: 24 | pass 25 | 26 | @abstractmethod 27 | async def connect(self) -> None: 28 | pass 29 | 30 | @abstractmethod 31 | async def is_connected(self) -> bool: 32 | pass 33 | 34 | @abstractmethod 35 | async def disconnect(self) -> None: 36 | pass 37 | 38 | @abstractmethod 39 | async def health_check(self) -> bool: 40 | pass 41 | 42 | @abstractmethod 43 | async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]: 44 | pass 45 | 46 | @abstractmethod 47 | async def send_tensor(self, shard: Shard, tensor: np.array, request_id: Optional[str] = None) -> Optional[np.array]: 48 | pass 49 | 50 | @abstractmethod 51 | async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None: 52 | pass 53 | 54 | 
@abstractmethod 55 | async def collect_topology(self, visited: set[str], max_depth: int) -> Topology: 56 | pass 57 | -------------------------------------------------------------------------------- /exo/networking/server.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class Server(ABC): 5 | @abstractmethod 6 | async def start(self) -> None: 7 | pass 8 | 9 | @abstractmethod 10 | async def stop(self) -> None: 11 | pass 12 | -------------------------------------------------------------------------------- /exo/networking/tailscale/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/networking/tailscale/__init__.py -------------------------------------------------------------------------------- /exo/networking/tailscale/tailscale_helpers.py: -------------------------------------------------------------------------------- 1 | import json 2 | import asyncio 3 | import aiohttp 4 | import re 5 | from typing import Dict, Any, Tuple, List, Optional 6 | from exo.helpers import DEBUG_DISCOVERY 7 | from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops 8 | from datetime import datetime, timezone 9 | 10 | 11 | class Device: 12 | def __init__(self, device_id: str, name: str, addresses: List[str], last_seen: Optional[datetime] = None): 13 | self.device_id = device_id 14 | self.name = name 15 | self.addresses = addresses 16 | self.last_seen = last_seen 17 | 18 | @classmethod 19 | def from_dict(cls, data: Dict[str, Any]) -> 'Device': 20 | return cls(device_id=data.get('id', ''), name=data.get('name', ''), addresses=data.get('addresses', []), last_seen=cls.parse_datetime(data.get('lastSeen'))) 21 | 22 | @staticmethod 23 | def parse_datetime(date_string: Optional[str]) -> Optional[datetime]: 24 | if not date_string: 25 | return None 26 | return datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) 27 | 28 | 29 | async def get_device_id() -> str: 30 | try: 31 | process = await asyncio.create_subprocess_exec('tailscale', 'status', '--json', stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) 32 | stdout, stderr = await process.communicate() 33 | if process.returncode != 0: 34 | raise Exception(f"Command failed with exit code {process.returncode}: {stderr.decode().strip()}.") 35 | if DEBUG_DISCOVERY >= 4: print(f"tailscale status: {stdout.decode()}") 36 | data = json.loads(stdout.decode()) 37 | return data['Self']['ID'] 38 | except Exception as e: 39 | raise Exception(f"{str(e)} Do you have the tailscale cli installed? 
See: https://tailscale.com/kb/1080/cli") 40 | 41 | 42 | async def update_device_attributes(device_id: str, api_key: str, node_id: str, node_port: int, device_capabilities: DeviceCapabilities): 43 | async with aiohttp.ClientSession() as session: 44 | base_url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes" 45 | headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'} 46 | 47 | attributes = { 48 | "custom:exo_node_id": node_id.replace('-', '_'), "custom:exo_node_port": node_port, "custom:exo_device_capability_chip": sanitize_attribute(device_capabilities.chip), 49 | "custom:exo_device_capability_model": sanitize_attribute(device_capabilities.model), "custom:exo_device_capability_memory": str(device_capabilities.memory), 50 | "custom:exo_device_capability_flops_fp16": str(device_capabilities.flops.fp16), "custom:exo_device_capability_flops_fp32": str(device_capabilities.flops.fp32), 51 | "custom:exo_device_capability_flops_int8": str(device_capabilities.flops.int8) 52 | } 53 | 54 | for attr_name, attr_value in attributes.items(): 55 | url = f"{base_url}/{attr_name}" 56 | data = {"value": str(attr_value).replace(' ', '_')} # Ensure all values are strings for JSON 57 | async with session.post(url, headers=headers, json=data) as response: 58 | if response.status == 200: 59 | if DEBUG_DISCOVERY >= 1: print(f"Updated device posture attribute {attr_name} for device {device_id}") 60 | else: 61 | print(f"Failed to update device posture attribute {attr_name}: {response.status} {await response.text()}") 62 | 63 | 64 | async def get_device_attributes(device_id: str, api_key: str) -> Tuple[str, int, DeviceCapabilities]: 65 | async with aiohttp.ClientSession() as session: 66 | url = f"https://api.tailscale.com/api/v2/device/{device_id}/attributes" 67 | headers = {'Authorization': f'Bearer {api_key}'} 68 | async with session.get(url, headers=headers) as response: 69 | if response.status == 200: 70 | data = await response.json() 71 | attributes = data.get("attributes", {}) 72 | node_id = attributes.get("custom:exo_node_id", "").replace('_', '-') 73 | node_port = int(attributes.get("custom:exo_node_port", 0)) 74 | device_capabilities = DeviceCapabilities( 75 | model=attributes.get("custom:exo_device_capability_model", "").replace('_', ' '), 76 | chip=attributes.get("custom:exo_device_capability_chip", "").replace('_', ' '), 77 | memory=int(attributes.get("custom:exo_device_capability_memory", 0)), 78 | flops=DeviceFlops( 79 | fp16=float(attributes.get("custom:exo_device_capability_flops_fp16", 0)), 80 | fp32=float(attributes.get("custom:exo_device_capability_flops_fp32", 0)), 81 | int8=float(attributes.get("custom:exo_device_capability_flops_int8", 0)) 82 | ) 83 | ) 84 | return node_id, node_port, device_capabilities 85 | else: 86 | print(f"Failed to fetch posture attributes for {device_id}: {response.status}") 87 | return "", 0, DeviceCapabilities(model="", chip="", memory=0, flops=DeviceFlops(fp16=0, fp32=0, int8=0)) 88 | 89 | 90 | def parse_device_attributes(data: Dict[str, str]) -> Dict[str, Any]: 91 | result = {} 92 | prefix = "custom:exo_" 93 | for key, value in data.items(): 94 | if key.startswith(prefix): 95 | attr_name = key.replace(prefix, "") 96 | if attr_name in ["node_id", "node_port", "device_capability_chip", "device_capability_model"]: 97 | result[attr_name] = value.replace('_', ' ') 98 | elif attr_name in ["device_capability_memory", "device_capability_flops_fp16", "device_capability_flops_fp32", "device_capability_flops_int8"]: 99 | 
result[attr_name] = float(value) 100 | return result 101 | 102 | 103 | def sanitize_attribute(value: str) -> str: 104 | # Replace invalid characters with underscores 105 | sanitized_value = re.sub(r'[^a-zA-Z0-9_.]', '_', value) 106 | # Truncate to 50 characters 107 | return sanitized_value[:50] 108 | 109 | 110 | async def get_tailscale_devices(api_key: str, tailnet: str) -> Dict[str, Device]: 111 | async with aiohttp.ClientSession() as session: 112 | url = f"https://api.tailscale.com/api/v2/tailnet/{tailnet}/devices" 113 | headers = {"Authorization": f"Bearer {api_key}"} 114 | 115 | async with session.get(url, headers=headers) as response: 116 | response.raise_for_status() 117 | data = await response.json() 118 | 119 | devices = {} 120 | for device_data in data.get("devices", []): 121 | print("Device data: ", device_data) 122 | device = Device.from_dict(device_data) 123 | devices[device.name] = device 124 | 125 | return devices 126 | -------------------------------------------------------------------------------- /exo/networking/tailscale/test_tailscale_discovery.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | import unittest 4 | from unittest import mock 5 | from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery 6 | from exo.networking.peer_handle import PeerHandle 7 | 8 | 9 | class TestTailscaleDiscovery(unittest.IsolatedAsyncioTestCase): 10 | async def asyncSetUp(self): 11 | self.tailscale_api_key = os.environ.get("TAILSCALE_API_KEY", "") 12 | self.tailnet = os.environ.get("TAILSCALE_TAILNET", "") 13 | self.discovery = TailscaleDiscovery( 14 | node_id="test_node", 15 | node_port=50051, 16 | create_peer_handle=lambda peer_id, address, description, device_capabilities: unittest.mock.Mock(spec=PeerHandle, id=lambda: peer_id), 17 | tailscale_api_key=self.tailscale_api_key, 18 | tailnet=self.tailnet 19 | ) 20 | await self.discovery.start() 21 | 22 | async def asyncTearDown(self): 23 | await self.discovery.stop() 24 | 25 | async def test_discovery(self): 26 | # Wait for a short period to allow discovery to happen 27 | await asyncio.sleep(15) 28 | 29 | # Get discovered peers 30 | peers = await self.discovery.discover_peers() 31 | 32 | # Check if any peers were discovered 33 | self.assertGreater(len(peers), 0, "No peers were discovered") 34 | 35 | # Print discovered peers for debugging 36 | print(f"Discovered peers: {[peer.id() for peer in peers]}") 37 | 38 | # Check if discovered peers are instances of GRPCPeerHandle 39 | print(peers) 40 | 41 | 42 | if __name__ == '__main__': 43 | unittest.main() 44 | -------------------------------------------------------------------------------- /exo/networking/udp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/networking/udp/__init__.py -------------------------------------------------------------------------------- /exo/networking/udp/test_udp_discovery.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import unittest 3 | from unittest import mock 4 | from exo.networking.udp.udp_discovery import UDPDiscovery 5 | from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle 6 | from exo.networking.grpc.grpc_server import GRPCServer 7 | from exo.orchestration.node import Node 8 | 9 | 10 | class TestUDPDiscovery(unittest.IsolatedAsyncioTestCase): 11 | async def 
asyncSetUp(self): 12 | self.peer1 = mock.AsyncMock() 13 | self.peer2 = mock.AsyncMock() 14 | self.peer1.connect = mock.AsyncMock() 15 | self.peer2.connect = mock.AsyncMock() 16 | self.discovery1 = UDPDiscovery("discovery1", 50051, 5678, 5679, create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1) 17 | self.discovery2 = UDPDiscovery("discovery2", 50052, 5679, 5678, create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2) 18 | await self.discovery1.start() 19 | await self.discovery2.start() 20 | 21 | async def asyncTearDown(self): 22 | await self.discovery1.stop() 23 | await self.discovery2.stop() 24 | 25 | async def test_discovery(self): 26 | peers1 = await self.discovery1.discover_peers(wait_for_peers=1) 27 | assert len(peers1) == 1 28 | peers2 = await self.discovery2.discover_peers(wait_for_peers=1) 29 | assert len(peers2) == 1 30 | 31 | # connect has to be explicitly called after discovery 32 | self.peer1.connect.assert_not_called() 33 | self.peer2.connect.assert_not_called() 34 | 35 | 36 | class TestUDPDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase): 37 | async def asyncSetUp(self): 38 | self.node1 = mock.AsyncMock(spec=Node) 39 | self.node2 = mock.AsyncMock(spec=Node) 40 | self.server1 = GRPCServer(self.node1, "localhost", 50053) 41 | self.server2 = GRPCServer(self.node2, "localhost", 50054) 42 | await self.server1.start() 43 | await self.server2.start() 44 | self.discovery1 = UDPDiscovery("discovery1", 50053, 5678, 5679, lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities)) 45 | self.discovery2 = UDPDiscovery("discovery2", 50054, 5679, 5678, lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities)) 46 | await self.discovery1.start() 47 | await self.discovery2.start() 48 | 49 | async def asyncTearDown(self): 50 | await self.discovery1.stop() 51 | await self.discovery2.stop() 52 | await self.server1.stop() 53 | await self.server2.stop() 54 | 55 | async def test_grpc_discovery(self): 56 | peers1 = await self.discovery1.discover_peers(wait_for_peers=1) 57 | assert len(peers1) == 1 58 | peers2 = await self.discovery2.discover_peers(wait_for_peers=1) 59 | assert len(peers2) == 1 60 | assert not await peers1[0].is_connected() 61 | assert not await peers2[0].is_connected() 62 | 63 | # Connect 64 | await peers1[0].connect() 65 | await peers2[0].connect() 66 | assert await peers1[0].is_connected() 67 | assert await peers2[0].is_connected() 68 | 69 | # Kill server1 70 | await self.server1.stop() 71 | 72 | assert await peers1[0].is_connected() 73 | assert not await peers2[0].is_connected() 74 | 75 | 76 | if __name__ == "__main__": 77 | asyncio.run(unittest.main()) 78 | -------------------------------------------------------------------------------- /exo/orchestration/__init__.py: -------------------------------------------------------------------------------- 1 | from .node import Node 2 | 3 | __all__ = ["Node"] 4 | -------------------------------------------------------------------------------- /exo/orchestration/test_node.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import Mock, AsyncMock 3 | import numpy as np 4 | import pytest 5 | 6 | from .node import Node 7 | from exo.networking.peer_handle import PeerHandle 8 | from exo.download.shard_download import NoopShardDownloader 9 | 10 | class 
TestNode(unittest.IsolatedAsyncioTestCase): 11 | def setUp(self): 12 | self.mock_inference_engine = AsyncMock() 13 | self.mock_server = AsyncMock() 14 | self.mock_server.start = AsyncMock() 15 | self.mock_server.stop = AsyncMock() 16 | self.mock_discovery = AsyncMock() 17 | self.mock_discovery.start = AsyncMock() 18 | self.mock_discovery.stop = AsyncMock() 19 | mock_peer1 = Mock(spec=PeerHandle) 20 | mock_peer1.id.return_value = "peer1" 21 | mock_peer2 = Mock(spec=PeerHandle) 22 | mock_peer2.id.return_value = "peer2" 23 | self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2]) 24 | 25 | self.node = Node("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery, NoopShardDownloader()) 26 | 27 | async def asyncSetUp(self): 28 | await self.node.start() 29 | 30 | async def asyncTearDown(self): 31 | await self.node.stop() 32 | 33 | async def test_node_initialization(self): 34 | self.assertEqual(self.node.node_id, "test_node") 35 | self.assertEqual(self.node.host, "localhost") 36 | self.assertEqual(self.node.port, 50051) 37 | 38 | async def test_node_start(self): 39 | self.mock_server.start.assert_called_once_with("localhost", 50051) 40 | 41 | async def test_node_stop(self): 42 | await self.node.stop() 43 | self.mock_server.stop.assert_called_once() 44 | 45 | async def test_discover_and_connect_to_peers(self): 46 | await self.node.discover_and_connect_to_peers() 47 | self.assertEqual(len(self.node.peers), 2) 48 | self.assertIn("peer1", map(lambda p: p.id(), self.node.peers)) 49 | self.assertIn("peer2", map(lambda p: p.id(), self.node.peers)) 50 | 51 | async def test_process_tensor_calls_inference_engine(self): 52 | mock_peer = Mock() 53 | self.node.peers = [mock_peer] 54 | 55 | input_tensor = np.array([69, 1, 2]) 56 | await self.node.process_tensor(input_tensor, None) 57 | 58 | self.node.inference_engine.process_shard.assert_called_once_with(input_tensor) 59 | 60 | @pytest.mark.asyncio 61 | async def test_node_capabilities(): 62 | node = Node() 63 | await node.initialize() 64 | caps = await node.get_device_capabilities() 65 | assert caps is not None 66 | assert caps.model != "" 67 | -------------------------------------------------------------------------------- /exo/orchestration/tracing.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, Optional, Any 3 | from opentelemetry import trace, context 4 | from opentelemetry.trace import Status, StatusCode, SpanContext 5 | from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator 6 | from contextlib import contextmanager 7 | import time 8 | from threading import Lock 9 | 10 | @dataclass 11 | class TraceContext: 12 | request_id: str 13 | sequence_number: int 14 | current_span: Optional[trace.Span] = None 15 | trace_parent: Optional[str] = None 16 | token_group_span: Optional[trace.Span] = None 17 | token_count: int = 0 18 | token_group_size: int = 10 # Default group size 19 | request_span: Optional[trace.Span] = None # Track the main request span 20 | 21 | class Tracer: 22 | def __init__(self): 23 | self.tracer = trace.get_tracer("exo") 24 | self.contexts: Dict[str, TraceContext] = {} 25 | self._lock = Lock() 26 | self.propagator = TraceContextTextMapPropagator() 27 | 28 | def get_context(self, request_id: str) -> Optional[TraceContext]: 29 | with self._lock: 30 | return self.contexts.get(request_id) 31 | 32 | def set_context(self, request_id: str, 
context: TraceContext): 33 | with self._lock: 34 | self.contexts[request_id] = context 35 | 36 | def inject_context(self, span: trace.Span) -> str: 37 | """Inject current span context into carrier for propagation""" 38 | carrier = {} 39 | ctx = trace.set_span_in_context(span) 40 | self.propagator.inject(carrier, context=ctx) 41 | return carrier.get("traceparent", "") 42 | 43 | def extract_context(self, trace_parent: str) -> Optional[context.Context]: 44 | """Extract span context from carrier""" 45 | if not trace_parent: 46 | return None 47 | carrier = {"traceparent": trace_parent} 48 | return self.propagator.extract(carrier) 49 | 50 | def create_context_from_parent(self, request_id: str, trace_parent: str, sequence_number: int = 0) -> TraceContext: 51 | """Create a new context with the given trace parent""" 52 | parent_ctx = self.extract_context(trace_parent) 53 | if parent_ctx: 54 | # Create a new request span that links to the parent context 55 | request_span = self.tracer.start_span( 56 | "request", 57 | context=parent_ctx, 58 | attributes={ 59 | "request_id": request_id, 60 | "sequence_number": sequence_number 61 | } 62 | ) 63 | return TraceContext( 64 | request_id=request_id, 65 | sequence_number=sequence_number, 66 | request_span=request_span, 67 | current_span=request_span, 68 | trace_parent=trace_parent 69 | ) 70 | return TraceContext(request_id=request_id, sequence_number=sequence_number) 71 | 72 | def handle_token(self, context: TraceContext, token: int, is_finished: bool = False): 73 | """Handle token generation and manage token group spans""" 74 | context.token_count += 1 75 | 76 | # Start a new token group span if needed 77 | if not context.token_group_span and context.request_span: 78 | group_number = (context.token_count - 1) // context.token_group_size + 1 79 | 80 | # Create token group span as child of request span 81 | parent_ctx = trace.set_span_in_context(context.request_span) 82 | context.token_group_span = self.tracer.start_span( 83 | f"token_group_{group_number}", 84 | context=parent_ctx, 85 | attributes={ 86 | "request_id": context.request_id, 87 | "group.number": group_number, 88 | "group.start_token": context.token_count, 89 | "group.max_tokens": context.token_group_size 90 | } 91 | ) 92 | 93 | # Add token to current group span 94 | if context.token_group_span: 95 | relative_pos = ((context.token_count - 1) % context.token_group_size) + 1 96 | context.token_group_span.set_attribute(f"token.{relative_pos}", token) 97 | context.token_group_span.set_attribute("token.count", relative_pos) 98 | 99 | # End current group span if we've reached the group size or if generation is finished 100 | if context.token_count % context.token_group_size == 0 or is_finished: 101 | context.token_group_span.set_attribute("token.final_count", relative_pos) 102 | context.token_group_span.end() 103 | context.token_group_span = None 104 | 105 | @contextmanager 106 | def start_span(self, name: str, context: TraceContext, extra_attributes: Optional[Dict[str, Any]] = None): 107 | """Start a new span with proper parent context""" 108 | attributes = { 109 | "request_id": context.request_id, 110 | "sequence_number": context.sequence_number 111 | } 112 | if extra_attributes: 113 | attributes.update(extra_attributes) 114 | 115 | # Use request span as parent if available 116 | parent_ctx = None 117 | if context.request_span: 118 | parent_ctx = trace.set_span_in_context(context.request_span) 119 | elif context.trace_parent: 120 | parent_ctx = self.extract_context(context.trace_parent) 121 | if 
parent_ctx and not context.request_span: 122 | # Create a new request span that links to the parent context 123 | context.request_span = self.tracer.start_span( 124 | "request", 125 | context=parent_ctx, 126 | attributes={ 127 | "request_id": context.request_id, 128 | "sequence_number": context.sequence_number 129 | } 130 | ) 131 | parent_ctx = trace.set_span_in_context(context.request_span) 132 | elif context.current_span: 133 | parent_ctx = trace.set_span_in_context(context.current_span) 134 | 135 | # Create span with parent context if it exists 136 | if parent_ctx: 137 | span = self.tracer.start_span( 138 | name, 139 | context=parent_ctx, 140 | attributes=attributes 141 | ) 142 | else: 143 | span = self.tracer.start_span( 144 | name, 145 | attributes=attributes 146 | ) 147 | 148 | # Update context with current span 149 | prev_span = context.current_span 150 | context.current_span = span 151 | 152 | try: 153 | start_time = time.perf_counter() 154 | yield span 155 | duration = time.perf_counter() - start_time 156 | span.set_attribute("duration_s", duration) 157 | span.set_status(Status(StatusCode.OK)) 158 | except Exception as e: 159 | span.set_status(Status(StatusCode.ERROR, str(e))) 160 | raise 161 | finally: 162 | span.end() 163 | context.current_span = prev_span 164 | 165 | # Global tracer instance 166 | tracer = Tracer() -------------------------------------------------------------------------------- /exo/test_callbacks.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Any, Callable 3 | from exo.helpers import AsyncCallbackSystem, AsyncCallback 4 | 5 | 6 | # Usage example 7 | async def main() -> None: 8 | callback_system = AsyncCallbackSystem[str, Any]() 9 | 10 | # Register callbacks 11 | callback1 = callback_system.register("callback1") 12 | callback2 = callback_system.register("callback2") 13 | 14 | def on_next_callback(name: str) -> Callable[..., None]: 15 | def callback(*args: Any) -> None: 16 | print(f"{name} received values: {args}") 17 | 18 | return callback 19 | 20 | callback1.on_next(on_next_callback("Callback1")) 21 | callback2.on_next(on_next_callback("Callback2")) 22 | 23 | async def wait_for_callback(name: str, callback: AsyncCallback[Any], condition: Callable[..., bool]) -> None: 24 | try: 25 | result = await callback.wait(condition, timeout=2) 26 | print(f"{name} wait completed with result: {result}") 27 | except asyncio.TimeoutError: 28 | print(f"{name} wait timed out") 29 | 30 | # Trigger all callbacks at once 31 | callback_system.trigger_all("Hello", 42, True) 32 | 33 | # Wait for all callbacks with different conditions 34 | await asyncio.gather( 35 | wait_for_callback("Callback1", callback1, lambda msg, num, flag: isinstance(msg, str) and num > 0), 36 | wait_for_callback("Callback2", callback2, lambda msg, num, flag: flag is True), 37 | ) 38 | 39 | # Trigger individual callback 40 | callback_system.trigger("callback2", "World", -10, False) 41 | 42 | # Demonstrate timeout 43 | new_callback = callback_system.register("new_callback") 44 | new_callback.on_next(on_next_callback("NewCallback")) 45 | await wait_for_callback("NewCallback", new_callback, lambda msg, num, flag: num > 100) 46 | 47 | callback_system.trigger("callback2", "World", 200, False) 48 | 49 | 50 | asyncio.run(main()) 51 | -------------------------------------------------------------------------------- /exo/tinychat/common.css: -------------------------------------------------------------------------------- 1 | /* make it 
responsive */ 2 | @media(min-width: 852px) { 3 | body { 4 | font-size: 14px; 5 | } 6 | } 7 | @media(max-width: 852px) { 8 | body { 9 | font-size: 12px; 10 | } 11 | } 12 | 13 | /* resets */ 14 | html, body { 15 | width: 100%; 16 | height: 100%; 17 | } 18 | 19 | *::-webkit-scrollbar { 20 | display: none; 21 | } 22 | 23 | * { 24 | -ms-overflow-style: none; 25 | scrollbar-width: none; 26 | } 27 | 28 | * { 29 | -moz-box-sizing: border-box; 30 | -webkit-box-sizing: border-box; 31 | box-sizing: border-box; 32 | } 33 | 34 | /* default */ 35 | body { 36 | margin: 0; 37 | background-color: var(--primary-bg-color); 38 | color: var(--foreground-color); 39 | } 40 | 41 | h1, h2, h3, h4, h5, h6 { 42 | margin: 0em; 43 | } 44 | 45 | hr { 46 | width: 92%; 47 | } 48 | 49 | button { 50 | cursor: pointer; 51 | border: none; 52 | background-color: transparent; 53 | } 54 | button:hover { 55 | } 56 | button:active { 57 | } 58 | 59 | /* components */ 60 | .container { 61 | margin: 0 auto; 62 | padding: 1rem; 63 | } 64 | 65 | .centered { 66 | display: flex; 67 | flex-direction: column; 68 | justify-content: center; 69 | align-items: center; 70 | } 71 | 72 | .centered-w-only { 73 | position: absolute; 74 | left: 50%; 75 | transform: translateX(-50%); 76 | } 77 | 78 | .centered-h-only { 79 | position: absolute; 80 | top: 50%; 81 | transform: translateY(-50%); 82 | } 83 | 84 | .card { 85 | padding: 0; 86 | } 87 | 88 | .card-header { 89 | padding: 0.5rem 1rem; 90 | } 91 | 92 | .card-container { 93 | width: 96vw; 94 | height: 100%; 95 | gap: 1rem; 96 | display: flex; 97 | flex-direction: row; 98 | flex-wrap: wrap; 99 | justify-content: center; 100 | align-items: center; 101 | } 102 | 103 | .clean-a { 104 | text-decoration: underline; 105 | text-decoration-color: #006fc1; 106 | text-decoration-thickness: 2px; 107 | color: inherit; 108 | } 109 | 110 | .hover-underline { 111 | text-decoration: underline; 112 | text-decoration-color: #228039; 113 | text-decoration-thickness: 2px; 114 | color: inherit; 115 | } 116 | 117 | .flex-horizontal { 118 | display: flex; 119 | flex-direction: row; 120 | justify-content: space-between; 121 | align-items: center; 122 | } 123 | 124 | .vertical-separator { 125 | padding: 0 0.5rem; 126 | } 127 | 128 | [x-cloak] { 129 | display: none !important; 130 | } 131 | -------------------------------------------------------------------------------- /exo/tinychat/favicon.svg: -------------------------------------------------------------------------------- 1 | 2 | 13 | 25 | 26 | -------------------------------------------------------------------------------- /exo/tinychat/static/cdn.jsdelivr.net/npm/@alpine-collective/toolkit@1.0.2/dist/cdn.min.js: -------------------------------------------------------------------------------- 1 | (()=>{var H=Object.create,v=Object.defineProperty,N=Object.getPrototypeOf,V=Object.prototype.hasOwnProperty,z=Object.getOwnPropertyNames,q=Object.getOwnPropertyDescriptor;var W=n=>v(n,"__esModule",{value:!0});var D=(n,e)=>()=>(e||(e={exports:{}},n(e.exports,e)),e.exports);var F=(n,e,o)=>{if(e&&typeof e=="object"||typeof e=="function")for(let r of z(e))!V.call(n,r)&&r!=="default"&&v(n,r,{get:()=>e[r],enumerable:!(o=q(e,r))||o.enumerable});return n},U=n=>F(W(v(n!=null?H(N(n)):{},"default",n&&n.__esModule&&"default"in n?{get:()=>n.default,enumerable:!0}:{value:n,enumerable:!0})),n);var I=D((E,w)=>{(function(){"use strict";function n(){var e=window,o=document;if("scrollBehavior"in o.documentElement.style&&e.__forceSmoothScrollPolyfill__!==!0)return;var 
r=e.HTMLElement||e.Element,i=468,f={scroll:e.scroll||e.scrollTo,scrollBy:e.scrollBy,elementScroll:r.prototype.scroll||b,scrollIntoView:r.prototype.scrollIntoView},u=e.performance&&e.performance.now?e.performance.now.bind(e.performance):Date.now;function c(t){var l=["MSIE ","Trident/","Edge/"];return new RegExp(l.join("|")).test(t)}var g=c(e.navigator.userAgent)?1:0;function b(t,l){this.scrollLeft=t,this.scrollTop=l}function M(t){return .5*(1-Math.cos(Math.PI*t))}function m(t){if(t===null||typeof t!="object"||t.behavior===void 0||t.behavior==="auto"||t.behavior==="instant")return!0;if(typeof t=="object"&&t.behavior==="smooth")return!1;throw new TypeError("behavior member of ScrollOptions "+t.behavior+" is not a valid value for enumeration ScrollBehavior.")}function O(t,l){if(l==="Y")return t.clientHeight+g1?1:a,s=M(a),d=t.startX+(t.x-t.startX)*s,p=t.startY+(t.y-t.startY)*s,t.method.call(t.scrollable,d,p),(d!==t.x||p!==t.y)&&e.requestAnimationFrame(S.bind(e,t))}function h(t,l,s){var d,p,a,y,_=u();t===o.body?(d=e,p=e.scrollX||e.pageXOffset,a=e.scrollY||e.pageYOffset,y=f.scroll):(d=t,p=t.scrollLeft,a=t.scrollTop,y=b),S({scrollable:d,method:y,startTime:_,startX:p,startY:a,x:l,y:s})}e.scroll=e.scrollTo=function(){if(arguments[0]!==void 0){if(m(arguments[0])===!0){f.scroll.call(e,arguments[0].left!==void 0?arguments[0].left:typeof arguments[0]!="object"?arguments[0]:e.scrollX||e.pageXOffset,arguments[0].top!==void 0?arguments[0].top:arguments[1]!==void 0?arguments[1]:e.scrollY||e.pageYOffset);return}h.call(e,o.body,arguments[0].left!==void 0?~~arguments[0].left:e.scrollX||e.pageXOffset,arguments[0].top!==void 0?~~arguments[0].top:e.scrollY||e.pageYOffset)}},e.scrollBy=function(){if(arguments[0]!==void 0){if(m(arguments[0])){f.scrollBy.call(e,arguments[0].left!==void 0?arguments[0].left:typeof arguments[0]!="object"?arguments[0]:0,arguments[0].top!==void 0?arguments[0].top:arguments[1]!==void 0?arguments[1]:0);return}h.call(e,o.body,~~arguments[0].left+(e.scrollX||e.pageXOffset),~~arguments[0].top+(e.scrollY||e.pageYOffset))}},r.prototype.scroll=r.prototype.scrollTo=function(){if(arguments[0]!==void 0){if(m(arguments[0])===!0){if(typeof arguments[0]=="number"&&arguments[1]===void 0)throw new SyntaxError("Value could not be converted");f.elementScroll.call(this,arguments[0].left!==void 0?~~arguments[0].left:typeof arguments[0]!="object"?~~arguments[0]:this.scrollLeft,arguments[0].top!==void 0?~~arguments[0].top:arguments[1]!==void 0?~~arguments[1]:this.scrollTop);return}var t=arguments[0].left,l=arguments[0].top;h.call(this,this,typeof t=="undefined"?this.scrollLeft:~~t,typeof l=="undefined"?this.scrollTop:~~l)}},r.prototype.scrollBy=function(){if(arguments[0]!==void 0){if(m(arguments[0])===!0){f.elementScroll.call(this,arguments[0].left!==void 0?~~arguments[0].left+this.scrollLeft:~~arguments[0]+this.scrollLeft,arguments[0].top!==void 0?~~arguments[0].top+this.scrollTop:~~arguments[1]+this.scrollTop);return}this.scroll({left:~~arguments[0].left+this.scrollLeft,top:~~arguments[0].top+this.scrollTop,behavior:arguments[0].behavior})}},r.prototype.scrollIntoView=function(){if(m(arguments[0])===!0){f.scrollIntoView.call(this,arguments[0]===void 0?!0:arguments[0]);return}var t=$(this),l=t.getBoundingClientRect(),s=this.getBoundingClientRect();t!==o.body?(h.call(this,t,t.scrollLeft+s.left-l.left,t.scrollTop+s.top-l.top),e.getComputedStyle(t).position!=="fixed"&&e.scrollBy({left:l.left,top:l.top,behavior:"smooth"})):e.scrollBy({left:s.left,top:s.top,behavior:"smooth"})}}typeof E=="object"&&typeof 
w!="undefined"?w.exports={polyfill:n}:n()})()});function j(n){n.magic("range",()=>function(e,o,r=1){typeof o=="undefined"&&(o=e,e=e?1:0);let i=e>o;i&&([e,o]=[o,e]);let f=Array.from({length:(o-e)/r+1},(u,c)=>e+c*r);return i?f.reverse():f})}var Y=U(I());function X(n){Y.default.polyfill(),n.magic("scroll",()=>function(e,o={}){let r=e,i=o.offset?parseInt(o.offset,10):0;if(delete o.offset,typeof e=="string"&&/^[0-9]+?/g.test(e)&&(e=parseInt(e,10)),typeof e=="string"&&(e=document.querySelector(e)),e instanceof Element&&(e=Math.floor(e.getBoundingClientRect().top+window.pageYOffset)),Number.isInteger(e)&&(e={top:e-i,behavior:"smooth"}),typeof e!="object")throw Error("Unsupported $scroll target: ",r);Object.assign(e,o),window.scroll(e)})}function B(n){let e=(o,r)=>{if(r[0].length<=o.length)return o;let i="\u2026";return typeof r[2]!="undefined"&&(i=r[2]),Object.prototype.hasOwnProperty.call(r[1],"ellipsis")&&(i=r[1].ellipsis),o+i};n.magic("truncate",()=>function(...o){return typeof o[0]!="string"||!o[1]?o[0]:typeof o[1]!="object"?e(o[0].slice(0,o[1]),o):Object.prototype.hasOwnProperty.call(o[1],"words")&&o[1].words?e(o[0].split(" ").splice(0,o[1].words).join(" "),o):Object.prototype.hasOwnProperty.call(o[1],"characters")&&o[1].characters?e(o[0].slice(0,o[1].characters),o):o[0]})}function L(n){n.magic("dbg",e=>function(...o){let r=o.map(i=>n.raw(i));console.log(...r)})}function x(n){let e=n.reactive({screensize:window.innerWidth}),o={xs:0,sm:640,md:768,lg:1024,xl:1280,"2xl":1536},r=window.AlpineMagicHelpersConfig&&window.AlpineMagicHelpersConfig.breakpoints?window.AlpineMagicHelpersConfig.breakpoints:o,i;window.addEventListener("resize",()=>{clearTimeout(i),i=setTimeout(()=>{e.screensize=window.innerWidth},150)}),n.magic("screen",f=>u=>{let c=e.screensize;if(Number.isInteger(u))return u<=c;if(r[u]===void 0)throw Error("Undefined $screen property: "+u+". 
Supported properties: "+Object.keys(r).join(", "));return r[u]<=c})}function P(n){n.magic("interval",()=>function(...e){if(typeof e[0]!="function")return e[0];let o=e[1],r=0,i=!1;typeof e[1]=="object"&&(Object.prototype.hasOwnProperty.call(e[1],"timer")&&(o=e[1].timer),Object.prototype.hasOwnProperty.call(e[1],"delay")&&(r=e[1].delay),Object.prototype.hasOwnProperty.call(e[1],"forceInterval")&&(i=e[1].forceInterval));let f=null,u=!0,c=()=>{let g=u?o+r:o;u=!1,f=setTimeout(()=>{e[0].call(this),i?c():requestAnimationFrame(c)},g)};n.effect(()=>{this.autoIntervalTest==null||this.autoIntervalTest?i?c():requestAnimationFrame(c):clearTimeout(f)})})}function C(n){j(n),X(n),B(n),L(n),x(n),P(n)}document.addEventListener("alpine:initializing",()=>{C(window.Alpine)});})(); 2 | -------------------------------------------------------------------------------- /exo/tinychat/static/cdn.jsdelivr.net/npm/@alpinejs/intersect@3.x.x/dist/cdn.min.js: -------------------------------------------------------------------------------- 1 | (()=>{function o(e){e.directive("intersect",e.skipDuringClone((t,{value:i,expression:l,modifiers:n},{evaluateLater:r,cleanup:c})=>{let s=r(l),a={rootMargin:x(n),threshold:f(n)},u=new IntersectionObserver(d=>{d.forEach(h=>{h.isIntersecting!==(i==="leave")&&(s(),n.includes("once")&&u.disconnect())})},a);u.observe(t),c(()=>{u.disconnect()})}))}function f(e){if(e.includes("full"))return .99;if(e.includes("half"))return .5;if(!e.includes("threshold"))return 0;let t=e[e.indexOf("threshold")+1];return t==="100"?1:t==="0"?0:Number(`.${t}`)}function p(e){let t=e.match(/^(-?[0-9]+)(px|%)?$/);return t?t[1]+(t[2]||"px"):void 0}function x(e){let t="margin",i="0px 0px 0px 0px",l=e.indexOf(t);if(l===-1)return i;let n=[];for(let r=1;r<5;r++)n.push(p(e[l+r]||""));return n=n.filter(r=>r!==void 0),n.length?n.join(" ").trim():i}document.addEventListener("alpine:init",()=>{window.Alpine.plugin(o)});})(); 2 | -------------------------------------------------------------------------------- /exo/tinychat/static/cdn.jsdelivr.net/npm/purecss@3.0.0/build/base-min.css: -------------------------------------------------------------------------------- 1 | /*! 2 | Pure v3.0.0 3 | Copyright 2013 Yahoo! 4 | Licensed under the BSD License. 5 | https://github.com/pure-css/pure/blob/master/LICENSE 6 | */ 7 | /*! 8 | normalize.css v | MIT License | https://necolas.github.io/normalize.css/ 9 | Copyright (c) Nicolas Gallagher and Jonathan Neal 10 | */ 11 | /*! 
normalize.css v8.0.1 | MIT License | github.com/necolas/normalize.css */html{line-height:1.15;-webkit-text-size-adjust:100%}body{margin:0}main{display:block}h1{font-size:2em;margin:.67em 0}hr{box-sizing:content-box;height:0;overflow:visible}pre{font-family:monospace,monospace;font-size:1em}a{background-color:transparent}abbr[title]{border-bottom:none;text-decoration:underline;-webkit-text-decoration:underline dotted;text-decoration:underline dotted}b,strong{font-weight:bolder}code,kbd,samp{font-family:monospace,monospace;font-size:1em}small{font-size:80%}sub,sup{font-size:75%;line-height:0;position:relative;vertical-align:baseline}sub{bottom:-.25em}sup{top:-.5em}img{border-style:none}button,input,optgroup,select,textarea{font-family:inherit;font-size:100%;line-height:1.15;margin:0}button,input{overflow:visible}button,select{text-transform:none}[type=button],[type=reset],[type=submit],button{-webkit-appearance:button}[type=button]::-moz-focus-inner,[type=reset]::-moz-focus-inner,[type=submit]::-moz-focus-inner,button::-moz-focus-inner{border-style:none;padding:0}[type=button]:-moz-focusring,[type=reset]:-moz-focusring,[type=submit]:-moz-focusring,button:-moz-focusring{outline:1px dotted ButtonText}fieldset{padding:.35em .75em .625em}legend{box-sizing:border-box;color:inherit;display:table;max-width:100%;padding:0;white-space:normal}progress{vertical-align:baseline}textarea{overflow:auto}[type=checkbox],[type=radio]{box-sizing:border-box;padding:0}[type=number]::-webkit-inner-spin-button,[type=number]::-webkit-outer-spin-button{height:auto}[type=search]{-webkit-appearance:textfield;outline-offset:-2px}[type=search]::-webkit-search-decoration{-webkit-appearance:none}::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}details{display:block}summary{display:list-item}template{display:none}[hidden]{display:none}html{font-family:sans-serif}.hidden,[hidden]{display:none!important}.pure-img{max-width:100%;height:auto;display:block} -------------------------------------------------------------------------------- /exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-brands-400.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-brands-400.ttf -------------------------------------------------------------------------------- /exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-brands-400.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-brands-400.woff2 -------------------------------------------------------------------------------- /exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-regular-400.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-regular-400.ttf -------------------------------------------------------------------------------- /exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-regular-400.woff2: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-regular-400.woff2 -------------------------------------------------------------------------------- /exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-solid-900.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-solid-900.ttf -------------------------------------------------------------------------------- /exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-solid-900.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-solid-900.woff2 -------------------------------------------------------------------------------- /exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-v4compatibility.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-v4compatibility.ttf -------------------------------------------------------------------------------- /exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-v4compatibility.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/tinychat/static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/webfonts/fa-v4compatibility.woff2 -------------------------------------------------------------------------------- /exo/tinychat/static/fonts.googleapis.com/css2: -------------------------------------------------------------------------------- 1 | @font-face { 2 | font-family: 'Megrim'; 3 | font-style: normal; 4 | font-weight: 400; 5 | font-display: swap; 6 | src: url(https://fonts.gstatic.com/s/megrim/v16/46kulbz5WjvLqJZlbQ.ttf) format('truetype'); 7 | } 8 | -------------------------------------------------------------------------------- /exo/tinychat/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/styles/vs2015.min.css: -------------------------------------------------------------------------------- 1 | pre code.hljs{display:block;overflow-x:auto;padding:1em}code.hljs{padding:3px 5px}.hljs{background:#1e1e1e;color:#dcdcdc}.hljs-keyword,.hljs-literal,.hljs-name,.hljs-symbol{color:#569cd6}.hljs-link{color:#569cd6;text-decoration:underline}.hljs-built_in,.hljs-type{color:#4ec9b0}.hljs-class,.hljs-number{color:#b8d7a3}.hljs-meta .hljs-string,.hljs-string{color:#d69d85}.hljs-regexp,.hljs-template-tag{color:#9a5334}.hljs-formula,.hljs-function,.hljs-params,.hljs-subst,.hljs-title{color:#dcdcdc}.hljs-comment,.hljs-quote{color:#57a64a;font-style:italic}.hljs-doctag{color:#608b4e}.hljs-meta,.hljs-meta 
.hljs-keyword,.hljs-tag{color:#9b9b9b}.hljs-template-variable,.hljs-variable{color:#bd63c5}.hljs-attr,.hljs-attribute{color:#9cdcfe}.hljs-section{color:gold}.hljs-emphasis{font-style:italic}.hljs-strong{font-weight:700}.hljs-bullet,.hljs-selector-attr,.hljs-selector-class,.hljs-selector-id,.hljs-selector-pseudo,.hljs-selector-tag{color:#d7ba7d}.hljs-addition{background-color:#144212;display:inline-block;width:100%}.hljs-deletion{background-color:#600;display:inline-block;width:100%} -------------------------------------------------------------------------------- /exo/tinychat/static/unpkg.com/@marcreichel/alpine-autosize@1.3.x/dist/alpine-autosize.min.js: -------------------------------------------------------------------------------- 1 | !function(e){"function"==typeof define&&define.amd?define(e):e()}((function(){"use strict";var e=new Map;function t(t){var o=e.get(t);o&&o.destroy()}function o(t){var o=e.get(t);o&&o.update()}var r=null;"undefined"==typeof window?((r=function(e){return e}).destroy=function(e){return e},r.update=function(e){return e}):((r=function(t,o){return t&&Array.prototype.forEach.call(t.length?t:[t],(function(t){return function(t){if(t&&t.nodeName&&"TEXTAREA"===t.nodeName&&!e.has(t)){var o,r=null,n=window.getComputedStyle(t),i=(o=t.value,function(){s({testForHeightReduction:""===o||!t.value.startsWith(o),restoreTextAlign:null}),o=t.value}),l=function(o){t.removeEventListener("autosize:destroy",l),t.removeEventListener("autosize:update",a),t.removeEventListener("input",i),window.removeEventListener("resize",a),Object.keys(o).forEach((function(e){return t.style[e]=o[e]})),e.delete(t)}.bind(t,{height:t.style.height,resize:t.style.resize,textAlign:t.style.textAlign,overflowY:t.style.overflowY,overflowX:t.style.overflowX,wordWrap:t.style.wordWrap});t.addEventListener("autosize:destroy",l),t.addEventListener("autosize:update",a),t.addEventListener("input",i),window.addEventListener("resize",a),t.style.overflowX="hidden",t.style.wordWrap="break-word",e.set(t,{destroy:l,update:a}),a()}function s(e){var o,i,l=e.restoreTextAlign,a=void 0===l?null:l,d=e.testForHeightReduction,u=void 0===d||d,c=n.overflowY;if(0!==t.scrollHeight&&("vertical"===n.resize?t.style.resize="none":"both"===n.resize&&(t.style.resize="horizontal"),u&&(o=function(e){for(var t=[];e&&e.parentNode&&e.parentNode instanceof Element;)e.parentNode.scrollTop&&t.push([e.parentNode,e.parentNode.scrollTop]),e=e.parentNode;return function(){return t.forEach((function(e){var t=e[0],o=e[1];t.style.scrollBehavior="auto",t.scrollTop=o,t.style.scrollBehavior=null}))}}(t),t.style.height=""),i="content-box"===n.boxSizing?t.scrollHeight-(parseFloat(n.paddingTop)+parseFloat(n.paddingBottom)):t.scrollHeight+parseFloat(n.borderTopWidth)+parseFloat(n.borderBottomWidth),"none"!==n.maxHeight&&i>parseFloat(n.maxHeight)?("hidden"===n.overflowY&&(t.style.overflow="scroll"),i=parseFloat(n.maxHeight)):"hidden"!==n.overflowY&&(t.style.overflow="hidden"),t.style.height=i+"px",a&&(t.style.textAlign=a),o&&o(),r!==i&&(t.dispatchEvent(new Event("autosize:resized",{bubbles:!0})),r=i),c!==n.overflow&&!a)){var f=n.textAlign;"hidden"===n.overflow&&(t.style.textAlign="start"===f?"end":"start"),s({restoreTextAlign:f,testForHeightReduction:!0})}}function a(){s({testForHeightReduction:!0,restoreTextAlign:null})}}(t)})),t}).destroy=function(e){return e&&Array.prototype.forEach.call(e.length?e:[e],t),e},r.update=function(e){return e&&Array.prototype.forEach.call(e.length?e:[e],o),e});var n=r;document.addEventListener("alpine:init",(()=>{var 
e;(e=window.Alpine).directive("autosize",((e,{modifiers:t},{cleanup:o})=>{n(e);const r=Array.from(e.attributes);let i=!1;for(let{nodeName:e}of r)if("wire:model"===e||e.startsWith("wire:model.")){i=!0;break}!e.hasAttribute("wire:ignore")&&i&&e.setAttribute("wire:ignore","");const l=()=>{n.update(e)};e.addEventListener("autosize",l),o((()=>{n.destroy(e),e.removeEventListener("autosize",l)}))})),e.magic("autosize",(e=>t=>{(t||e).dispatchEvent(new Event("autosize"))}))}))})); 2 | //# sourceMappingURL=alpine-autosize.min.js.map 3 | -------------------------------------------------------------------------------- /exo/tinychat/static/unpkg.com/marked-highlight@2.1.2/lib/index.umd.js: -------------------------------------------------------------------------------- 1 | (function (global, factory) { 2 | typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports) : 3 | typeof define === 'function' && define.amd ? define(['exports'], factory) : 4 | (global = typeof globalThis !== 'undefined' ? globalThis : global || self, factory(global.markedHighlight = {})); 5 | })(this, (function (exports) { 'use strict'; 6 | 7 | function markedHighlight(options) { 8 | if (typeof options === 'function') { 9 | options = { 10 | highlight: options 11 | }; 12 | } 13 | 14 | if (!options || typeof options.highlight !== 'function') { 15 | throw new Error('Must provide highlight function'); 16 | } 17 | 18 | if (typeof options.langPrefix !== 'string') { 19 | options.langPrefix = 'language-'; 20 | } 21 | 22 | return { 23 | async: !!options.async, 24 | walkTokens(token) { 25 | if (token.type !== 'code') { 26 | return; 27 | } 28 | 29 | const lang = getLang(token.lang); 30 | 31 | if (options.async) { 32 | return Promise.resolve(options.highlight(token.text, lang, token.lang || '')).then(updateToken(token)); 33 | } 34 | 35 | const code = options.highlight(token.text, lang, token.lang || ''); 36 | if (code instanceof Promise) { 37 | throw new Error('markedHighlight is not set to async but the highlight function is async. Set the async option to true on markedHighlight to await the async highlight function.'); 38 | } 39 | updateToken(token)(code); 40 | }, 41 | useNewRenderer: true, 42 | renderer: { 43 | code({ text, lang, escaped }) { 44 | const language = getLang(lang); 45 | const classAttr = language 46 | ? ` class="${options.langPrefix}${escape(language)}"` 47 | : ''; 48 | text = text.replace(/\n$/, ''); 49 | return `
<pre><code${classAttr}>${escaped ? text : escape(text, true)}\n</code></pre>
`; 50 | } 51 | } 52 | }; 53 | } 54 | 55 | function getLang(lang) { 56 | return (lang || '').match(/\S*/)[0]; 57 | } 58 | 59 | function updateToken(token) { 60 | return (code) => { 61 | if (typeof code === 'string' && code !== token.text) { 62 | token.escaped = true; 63 | token.text = code; 64 | } 65 | }; 66 | } 67 | 68 | // copied from marked helpers 69 | const escapeTest = /[&<>"']/; 70 | const escapeReplace = new RegExp(escapeTest.source, 'g'); 71 | const escapeTestNoEncode = /[<>"']|&(?!(#\d{1,7}|#[Xx][a-fA-F0-9]{1,6}|\w+);)/; 72 | const escapeReplaceNoEncode = new RegExp(escapeTestNoEncode.source, 'g'); 73 | const escapeReplacements = { 74 | '&': '&', 75 | '<': '<', 76 | '>': '>', 77 | '"': '"', 78 | "'": ''' 79 | }; 80 | const getEscapeReplacement = (ch) => escapeReplacements[ch]; 81 | function escape(html, encode) { 82 | if (encode) { 83 | if (escapeTest.test(html)) { 84 | return html.replace(escapeReplace, getEscapeReplacement); 85 | } 86 | } else { 87 | if (escapeTestNoEncode.test(html)) { 88 | return html.replace(escapeReplaceNoEncode, getEscapeReplacement); 89 | } 90 | } 91 | 92 | return html; 93 | } 94 | 95 | exports.markedHighlight = markedHighlight; 96 | 97 | })); 98 | -------------------------------------------------------------------------------- /exo/tinychat/update_deps.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | from bs4 import BeautifulSoup 4 | from urllib.parse import urljoin, urlparse 5 | import re 6 | 7 | 8 | def download_file(url, local_path): 9 | response = requests.get(url) 10 | if response.status_code == 200: 11 | os.makedirs(os.path.dirname(local_path), exist_ok=True) 12 | with open(local_path, 'wb') as f: 13 | f.write(response.content) 14 | print(f"Downloaded: {local_path}") 15 | else: 16 | print(response.status_code) 17 | print(f"Failed to download: {url}") 18 | 19 | 20 | def update_html(html_content, base_url): 21 | soup = BeautifulSoup(html_content, 'html.parser') 22 | 23 | for tag in soup.find_all(['script', 'link']): 24 | if tag.has_attr('src'): 25 | url = tag['src'] 26 | elif tag.has_attr('href'): 27 | url = tag['href'] 28 | else: 29 | continue 30 | 31 | if url.startswith(('http://', 'https://')): 32 | full_url = url 33 | else: 34 | full_url = urljoin(base_url, url) 35 | 36 | parsed_url = urlparse(full_url) 37 | local_path = os.path.join('static', parsed_url.netloc, parsed_url.path.lstrip('/')) 38 | 39 | download_file(full_url, local_path) 40 | 41 | relative_path = os.path.relpath(local_path, '.') 42 | if tag.name == 'script': 43 | tag['src'] = "/" + relative_path 44 | elif tag.name == 'link': 45 | tag['href'] = "/" + relative_path 46 | 47 | return str(soup) 48 | 49 | 50 | # Read the HTML file 51 | with open('./index.html', 'r') as f: 52 | html_content = f.read() 53 | 54 | # Update HTML and download files 55 | # updated_html = update_html(html_content, 'https://example.com') 56 | 57 | # # Write the updated HTML 58 | # with open('./index.html', 'w') as f: 59 | # f.write(updated_html) 60 | 61 | print("HTML file updated with local paths.") 62 | 63 | # Download Font Awesome CSS and font files 64 | base_url = "https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/" 65 | css_url = urljoin(base_url, "css/all.min.css") 66 | output_dir = "static/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2" 67 | 68 | # Download CSS file 69 | css_output_path = os.path.join(output_dir, "css", "all.min.css") 70 | download_file(css_url, css_output_path) 71 | 72 | # Parse CSS file for font URLs 73 | 
with open(css_output_path, 'r', encoding='utf-8') as f: 74 | css_content = f.read() 75 | 76 | # Extract font URLs from the CSS content 77 | font_urls = re.findall(r'url\((.*?\.(?:woff2|ttf))\)', css_content) 78 | 79 | print(f"Found {len(font_urls)} font URLs") 80 | 81 | # Download font files 82 | for font_url in font_urls: 83 | font_url = font_url.strip('"\'') 84 | if font_url.startswith('../'): 85 | font_url = font_url[3:] 86 | 87 | # Use base_url instead of urljoin to keep the version number 88 | full_url = base_url + font_url 89 | relative_path = font_url 90 | output_path = os.path.join(output_dir, relative_path) 91 | download_file(full_url, output_path) 92 | 93 | print("Download complete!") 94 | -------------------------------------------------------------------------------- /exo/topology/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/topology/__init__.py -------------------------------------------------------------------------------- /exo/topology/partitioning_strategy.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Dict 3 | from dataclasses import dataclass 4 | from .topology import Topology 5 | from exo.inference.shard import Shard 6 | from exo.topology.device_capabilities import device_capabilities 7 | import asyncio 8 | 9 | 10 | # Partitions shard-space into pieces of contiguous shards, represented by floating point range [start, end) between 0 and 1 11 | @dataclass 12 | class Partition: 13 | node_id: str 14 | start: float 15 | end: float 16 | 17 | 18 | class PartitioningStrategy(ABC): 19 | @abstractmethod 20 | def partition(self, topology: Topology) -> List[Partition]: 21 | pass 22 | 23 | 24 | def map_partitions_to_shards(partitions: List[Partition], num_layers: int, model_id: str) -> List[Shard]: 25 | shards = [] 26 | for i, partition in enumerate(partitions): 27 | start_layer = int(partition.start*num_layers) 28 | end_layer = int(partition.end*num_layers) - 1 29 | 30 | # Ensure the last partition covers up to num_layers - 1 31 | if i == len(partitions) - 1: 32 | end_layer = num_layers - 1 33 | 34 | # Ensure no empty shards 35 | if start_layer <= end_layer: 36 | shards.append(Shard(model_id, start_layer, end_layer, num_layers)) 37 | 38 | # Ensure full coverage 39 | if shards and shards[-1].end_layer < num_layers - 1: 40 | shards[-1] = Shard(model_id, shards[-1].start_layer, num_layers - 1, num_layers) 41 | 42 | return shards 43 | -------------------------------------------------------------------------------- /exo/topology/ring_memory_weighted_partitioning_strategy.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from .partitioning_strategy import PartitioningStrategy 3 | from .topology import Topology 4 | from .partitioning_strategy import Partition 5 | 6 | 7 | class RingMemoryWeightedPartitioningStrategy(PartitioningStrategy): 8 | def partition(self, topology: Topology) -> List[Partition]: 9 | nodes = list(topology.all_nodes()) 10 | nodes.sort(key=lambda x: (x[1].memory, x[0]), reverse=True) 11 | total_memory = sum(node[1].memory for node in nodes) 12 | partitions = [] 13 | start = 0 14 | for node in nodes: 15 | end = round(start + (node[1].memory/total_memory), 5) 16 | partitions.append(Partition(node[0], start, end)) 17 | start = end 18 | return partitions 19 | 
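# A minimal usage sketch (illustrative only; the 32-layer count, the "model" id and a
# pre-built `topology` are assumptions, not values taken from this repo): the strategy
# above yields memory-weighted [start, end) ranges, and map_partitions_to_shards from
# partitioning_strategy.py turns them into contiguous layer shards covering every layer once.
#
#   from exo.topology.partitioning_strategy import map_partitions_to_shards
#   partitions = RingMemoryWeightedPartitioningStrategy().partition(topology)
#   shards = map_partitions_to_shards(partitions, num_layers=32, model_id="model")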
-------------------------------------------------------------------------------- /exo/topology/test_device_capabilities.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import patch 3 | from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities, DeviceFlops, TFLOPS, device_capabilities 4 | 5 | 6 | @pytest.mark.asyncio 7 | @patch("subprocess.check_output") 8 | async def test_mac_device_capabilities_pro(mock_check_output): 9 | # Mock the subprocess output 10 | mock_check_output.return_value = b""" 11 | Hardware: 12 | 13 | Hardware Overview: 14 | 15 | Model Name: MacBook Pro 16 | Model Identifier: Mac15,9 17 | Model Number: Z1CM000EFB/A 18 | Chip: Apple M3 Max 19 | Total Number of Cores: 16 (12 performance and 4 efficiency) 20 | Memory: 128 GB 21 | System Firmware Version: 10000.000.0 22 | OS Loader Version: 10000.000.0 23 | Serial Number (system): XXXXXXXXXX 24 | Hardware UUID: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX 25 | Provisioning UDID: XXXXXXXX-XXXXXXXXXXXXXXXX 26 | Activation Lock Status: Enabled 27 | """ 28 | 29 | # Call the function 30 | result = await mac_device_capabilities() 31 | 32 | # Check the results 33 | assert isinstance(result, DeviceCapabilities) 34 | assert result.model == "MacBook Pro" 35 | assert result.chip == "Apple M3 Max" 36 | assert result.memory == 131072 # 128 GB in MB 37 | assert str(result) == "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS" 38 | 39 | 40 | @pytest.mark.asyncio 41 | @patch("subprocess.check_output") 42 | async def test_mac_device_capabilities_air(mock_check_output): 43 | # Mock the subprocess output 44 | mock_check_output.return_value = b""" 45 | Hardware: 46 | 47 | Hardware Overview: 48 | 49 | Model Name: MacBook Air 50 | Model Identifier: Mac14,2 51 | Model Number: MLY33B/A 52 | Chip: Apple M2 53 | Total Number of Cores: 8 (4 performance and 4 efficiency) 54 | Memory: 8 GB 55 | System Firmware Version: 10000.00.0 56 | OS Loader Version: 10000.00.0 57 | Serial Number (system): XXXXXXXXXX 58 | Hardware UUID: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX 59 | Provisioning UDID: XXXXXXXX-XXXXXXXXXXXXXXXX 60 | Activation Lock Status: Disabled 61 | """ 62 | 63 | # Call the function 64 | result = await mac_device_capabilities() 65 | 66 | # Check the results 67 | assert isinstance(result, DeviceCapabilities) 68 | assert result.model == "MacBook Air" 69 | assert result.chip == "Apple M2" 70 | assert result.memory == 8192 # 8 GB in MB 71 | 72 | 73 | @pytest.mark.skip(reason="Unskip this test when running on a MacBook Pro, Apple M3 Max, 128GB") 74 | @pytest.mark.asyncio 75 | async def test_mac_device_capabilities_real(): 76 | # Call the function without mocking 77 | result = await mac_device_capabilities() 78 | 79 | # Check the results 80 | assert isinstance(result, DeviceCapabilities) 81 | assert result.model == "MacBook Pro" 82 | assert result.chip == "Apple M3 Max" 83 | assert result.memory == 131072 # 128 GB in MB 84 | assert result.flops == DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS) 85 | assert str(result) == "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. 
Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS" 86 | 87 | 88 | @pytest.mark.asyncio 89 | async def test_device_capabilities(): 90 | caps = await device_capabilities() 91 | assert caps.model != "" 92 | assert caps.chip != "" 93 | assert caps.memory > 0 94 | assert caps.flops is not None 95 | -------------------------------------------------------------------------------- /exo/topology/test_map_partitions.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from typing import List 3 | from exo.topology.partitioning_strategy import Partition, map_partitions_to_shards 4 | from exo.inference.shard import Shard 5 | 6 | 7 | class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase): 8 | def test_map_partitions_to_shards(self): 9 | partitions = [ 10 | Partition("node1", 0.0, 0.42857), 11 | Partition("node2", 0.42857, 0.71428), 12 | Partition("node3", 0.71428, 0.99999), 13 | ] 14 | shards = map_partitions_to_shards(partitions, 32, "model") 15 | self.assertEqual( 16 | shards, 17 | [ 18 | Shard("model", 0, 12, 32), 19 | Shard("model", 13, 21, 32), 20 | Shard("model", 22, 31, 32), 21 | ], 22 | ) 23 | 24 | partitions = [ 25 | Partition("node1", 0.0, 0.1), 26 | Partition("node2", 0.1, 0.2), 27 | Partition("node3", 0.2, 1.0), 28 | ] 29 | shards = map_partitions_to_shards(partitions, 32, "model") 30 | self.assertEqual( 31 | shards, 32 | [ 33 | Shard("model", 0, 2, 32), 34 | Shard("model", 3, 5, 32), 35 | Shard("model", 6, 31, 32), 36 | ], 37 | ) 38 | 39 | partitions = [ 40 | Partition("node1", 0.0, 1.0), 41 | ] 42 | shards = map_partitions_to_shards(partitions, 32, "model") 43 | self.assertEqual( 44 | shards, 45 | [ 46 | Shard("model", 0, 31, 32), 47 | ], 48 | ) 49 | 50 | partitions = [] 51 | shards = map_partitions_to_shards(partitions, 32, "model") 52 | self.assertEqual(shards, []) 53 | 54 | def test_broken_map_partitions_to_shards(self): 55 | # this was an old broken implementation that sometimes had rounding errors! 
56 | def _broken_map_partitions_to_shards(partitions: List[Partition], num_layers, model_id: str): 57 | shards = [] 58 | for i, partition in enumerate(partitions): 59 | start_layer = int(partition.start*num_layers) 60 | end_layer = int(partition.end*num_layers) - 1 61 | shards.append(Shard(model_id, start_layer, end_layer, num_layers)) 62 | return shards 63 | 64 | partitions = [ 65 | Partition("node1", 0.0, 0.42857), 66 | Partition("node2", 0.42857, 0.71428), 67 | Partition("node3", 0.71428, 0.99999), 68 | ] 69 | shards = _broken_map_partitions_to_shards(partitions, 32, "model") 70 | self.assertEqual( 71 | shards, 72 | [ 73 | Shard("model", 0, 12, 32), 74 | Shard("model", 13, 21, 32), 75 | Shard("model", 22, 30, 32), 76 | ], 77 | ) 78 | 79 | 80 | if __name__ == "__main__": 81 | unittest.main() 82 | -------------------------------------------------------------------------------- /exo/topology/test_ring_memory_weighted_partitioning_strategy.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy 3 | from exo.topology.topology import Topology 4 | from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops 5 | from exo.topology.partitioning_strategy import Partition 6 | 7 | 8 | class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase): 9 | def test_partition(self): 10 | # triangle 11 | # node1 -> node2 -> node3 -> node1 12 | topology = Topology() 13 | topology.update_node( 14 | "node1", 15 | DeviceCapabilities(model="test1", chip="test1", memory=3000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)), 16 | ) 17 | topology.update_node( 18 | "node2", 19 | DeviceCapabilities(model="test2", chip="test2", memory=1000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)), 20 | ) 21 | topology.update_node( 22 | "node3", 23 | DeviceCapabilities(model="test3", chip="test3", memory=6000, flops=DeviceFlops(fp32=0, fp16=0, int8=0)), 24 | ) 25 | topology.add_edge("node1", "node2") 26 | topology.add_edge("node2", "node3") 27 | topology.add_edge("node3", "node1") 28 | topology.add_edge("node1", "node3") 29 | 30 | strategy = RingMemoryWeightedPartitioningStrategy() 31 | partitions = strategy.partition(topology) 32 | 33 | self.assertEqual(len(partitions), 3) 34 | self.assertEqual( 35 | partitions, 36 | [ 37 | Partition("node3", 0.0, 0.6), 38 | Partition("node1", 0.6, 0.9), 39 | Partition("node2", 0.9, 1.0), 40 | ], 41 | ) 42 | 43 | def test_partition_rounding(self): 44 | # triangle 45 | # node1 -> node2 -> node3 -> node1 46 | topology = Topology() 47 | topology.update_node( 48 | "node1", 49 | DeviceCapabilities( 50 | model="MacBook Pro", 51 | chip="test1", 52 | memory=128*1024*1024*1024, 53 | flops=DeviceFlops(fp32=0, fp16=0, int8=0), 54 | ), 55 | ) 56 | topology.update_node( 57 | "node2", 58 | DeviceCapabilities( 59 | model="Mac Studio", 60 | chip="test2", 61 | memory=192*1024*1024*1024, 62 | flops=DeviceFlops(fp32=0, fp16=0, int8=0), 63 | ), 64 | ) 65 | topology.update_node( 66 | "node3", 67 | DeviceCapabilities( 68 | model="MacBook Pro", 69 | chip="test3", 70 | memory=128*1024*1024*1024, 71 | flops=DeviceFlops(fp32=0, fp16=0, int8=0), 72 | ), 73 | ) 74 | 75 | strategy = RingMemoryWeightedPartitioningStrategy() 76 | partitions = strategy.partition(topology) 77 | 78 | self.assertEqual(len(partitions), 3) 79 | self.assertEqual( 80 | partitions, 81 | [ 82 | Partition("node3", 0.0, 0.42857), 83 | Partition("node1", 0.6, 0.9), 84 | 
Partition("node2", 0.9, 1.0), 85 | ], 86 | ) 87 | 88 | 89 | if __name__ == "__main__": 90 | unittest.main() 91 | -------------------------------------------------------------------------------- /exo/topology/topology.py: -------------------------------------------------------------------------------- 1 | from .device_capabilities import DeviceCapabilities 2 | from typing import Dict, Set, Optional 3 | from dataclasses import dataclass 4 | 5 | @dataclass 6 | class PeerConnection: 7 | from_id: str 8 | to_id: str 9 | description: Optional[str] = None 10 | 11 | def __hash__(self): 12 | # Use both from_id and to_id for uniqueness in sets 13 | return hash((self.from_id, self.to_id)) 14 | 15 | def __eq__(self, other): 16 | if not isinstance(other, PeerConnection): 17 | return False 18 | # Compare both from_id and to_id for equality 19 | return self.from_id == other.from_id and self.to_id == other.to_id 20 | 21 | class Topology: 22 | def __init__(self): 23 | self.nodes: Dict[str, DeviceCapabilities] = {} 24 | self.peer_graph: Dict[str, Set[PeerConnection]] = {} 25 | self.active_node_id: Optional[str] = None 26 | 27 | def update_node(self, node_id: str, device_capabilities: DeviceCapabilities): 28 | self.nodes[node_id] = device_capabilities 29 | 30 | def get_node(self, node_id: str) -> DeviceCapabilities: 31 | return self.nodes.get(node_id) 32 | 33 | def all_nodes(self): 34 | return self.nodes.items() 35 | 36 | def add_edge(self, from_id: str, to_id: str, description: Optional[str] = None): 37 | if from_id not in self.peer_graph: 38 | self.peer_graph[from_id] = set() 39 | conn = PeerConnection(from_id, to_id, description) 40 | self.peer_graph[from_id].add(conn) 41 | 42 | def merge(self, peer_node_id: str, other: "Topology"): 43 | for node_id, capabilities in other.nodes.items(): 44 | if node_id != peer_node_id: continue 45 | self.update_node(node_id, capabilities) 46 | for node_id, connections in other.peer_graph.items(): 47 | for conn in connections: 48 | if conn.from_id != peer_node_id: continue 49 | self.add_edge(conn.from_id, conn.to_id, conn.description) 50 | 51 | def __str__(self): 52 | nodes_str = ", ".join(f"{node_id}: {cap}" for node_id, cap in self.nodes.items()) 53 | edges_str = ", ".join(f"{node}: {[f'{c.to_id}({c.description})' for c in conns]}" 54 | for node, conns in self.peer_graph.items()) 55 | return f"Topology(Nodes: {{{nodes_str}}}, Edges: {{{edges_str}}})" 56 | 57 | def to_json(self): 58 | return { 59 | "nodes": { 60 | node_id: capabilities.to_dict() 61 | for node_id, capabilities in self.nodes.items() 62 | }, 63 | "peer_graph": { 64 | node_id: [ 65 | { 66 | "from_id": conn.from_id, 67 | "to_id": conn.to_id, 68 | "description": conn.description 69 | } 70 | for conn in connections 71 | ] 72 | for node_id, connections in self.peer_graph.items() 73 | }, 74 | "active_node_id": self.active_node_id 75 | } 76 | -------------------------------------------------------------------------------- /exo/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/train/__init__.py -------------------------------------------------------------------------------- /exo/train/dataset.py: -------------------------------------------------------------------------------- 1 | #from https://github.com/ml-explore/mlx-examples 2 | from pathlib import Path 3 | import numpy as np 4 | import json 5 | from functools import partial, reduce 6 | def compose(*funcs): 7 | return 
reduce(lambda f, g: lambda x: f(g(x)), funcs, lambda x : x) 8 | 9 | def batch_with_lengths(tokens, maxlen = None): 10 | lengths = [len(x) for x in tokens] 11 | batch_size = len(lengths) 12 | if maxlen is None: 13 | maxlen = max(lengths) 14 | else: 15 | lengths = [min(maxlen, l) for l in lengths] 16 | 17 | # Pad to the max length 18 | batch_arr = np.zeros((batch_size, maxlen), np.int32) 19 | 20 | for j in range(batch_size): 21 | batch_arr[j, : lengths[j]] = tokens[j] 22 | batch = np.array(batch_arr) 23 | return batch[:, :-1], batch[:, 1:], np.array(lengths) 24 | 25 | def batch_chunk(batch_size): 26 | return lambda d, i: d[i:i + batch_size] 27 | 28 | 29 | def iterate_batches(dset, batch_size, train=False, uniform_length=True): 30 | # Shuffle indices 31 | make_batch = lambda b: batch_with_lengths(b, maxlen=dset._maxlen if uniform_length else None) 32 | chunk = batch_chunk(batch_size) 33 | while True: 34 | indices = np.arange(len(dset)) 35 | if train: 36 | indices = np.random.permutation(indices) 37 | batch = compose(make_batch, lambda i: [dset[k] for k in i], partial(chunk, indices)) 38 | 39 | # Collect batches from dataset 40 | for i in range(0, len(indices) - batch_size + 1, batch_size): 41 | yield batch(i) 42 | 43 | if not train: 44 | break 45 | 46 | class Dataset: 47 | def __init__(self, path: Path, preprocess=lambda item: item, loadline=json.loads, metrics={}): 48 | if not path.exists(): 49 | self._data = None 50 | else: 51 | self.preprocess = preprocess 52 | with open(path, "r") as fid: 53 | self._data = [loadline(l) for l in fid] 54 | self._maxlen = max([len(preprocess(x)) for x in self._data]) 55 | # Check if any sequence is longer than 2048 tokens 56 | if self._maxlen > 2048: 57 | print("You've got sequences with over 2048 tokens in here! Split your data fool!") 58 | 59 | 60 | def __getitem__(self, idx: int): 61 | return self.preprocess(self._data[idx]) 62 | 63 | def __len__(self): 64 | return len(self._data) 65 | 66 | 67 | def load_dataset(data_path: str, preprocess=lambda i: i, loadline=json.loads): 68 | def load_and_check(name): 69 | dataset_path = Path(data_path) / f"{name}.jsonl" 70 | try: 71 | return Dataset(dataset_path, preprocess=preprocess, loadline=loadline) 72 | except Exception as e: 73 | print(f"Unable to build dataset {dataset_path} ({e})") 74 | raise 75 | 76 | names = ("train", "valid", "test") 77 | train, valid, test = (load_and_check(n) for n in names) 78 | 79 | return train, valid, test 80 | 81 | -------------------------------------------------------------------------------- /exo/viz/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/exo-explore/exo/e4238f9ef369037252c7542e40ea1a8a625afba7/exo/viz/__init__.py -------------------------------------------------------------------------------- /exo/viz/test_topology_viz.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import unittest 3 | from datetime import timedelta 4 | from exo.viz.topology_viz import TopologyViz 5 | from exo.topology.topology import Topology 6 | from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops 7 | from exo.topology.partitioning_strategy import Partition 8 | from exo.download.download_progress import RepoProgressEvent 9 | 10 | 11 | def create_hf_repo_progress_event( 12 | completed_files: int = 5, 13 | total_files: int = 10, 14 | downloaded_bytes: int = 500000000, 15 | downloaded_bytes_this_session: int = 250000000, 16 | total_bytes: int = 
1000000000, 17 | overall_speed: int = 5000000, 18 | overall_eta: timedelta = timedelta(seconds=100), 19 | file_progress: dict = None, 20 | status: str = "in_progress" 21 | ) -> RepoProgressEvent: 22 | if file_progress is None: 23 | file_progress = { 24 | "file1.bin": 25 | RepoFileProgressEvent( 26 | repo_id="repo_id", 27 | repo_revision="repo_revision", 28 | file_path="file1.bin", 29 | downloaded=100000000, 30 | downloaded_this_session=50000000, 31 | total=200000000, 32 | speed=1000000, 33 | eta=timedelta(seconds=100), 34 | status="in_progress" 35 | ), "file2.bin": 36 | RepoFileProgressEvent( 37 | repo_id="repo_id", 38 | repo_revision="repo_revision", 39 | file_path="file2.bin", 40 | downloaded=200000000, 41 | downloaded_this_session=100000000, 42 | total=200000000, 43 | speed=2000000, 44 | eta=timedelta(seconds=0), 45 | status="complete" 46 | ) 47 | } 48 | 49 | return RepoProgressEvent( 50 | repo_id="repo_id", 51 | repo_revision="repo_revision", 52 | completed_files=completed_files, 53 | total_files=total_files, 54 | downloaded_bytes=downloaded_bytes, 55 | downloaded_bytes_this_session=downloaded_bytes_this_session, 56 | total_bytes=total_bytes, 57 | overall_speed=overall_speed, 58 | overall_eta=overall_eta, 59 | file_progress=file_progress, 60 | status=status 61 | ) 62 | 63 | 64 | class TestNodeViz(unittest.IsolatedAsyncioTestCase): 65 | async def asyncSetUp(self): 66 | self.topology = Topology() 67 | self.topology.update_node( 68 | "node1", 69 | DeviceCapabilities(model="ModelA", chip="ChipA", memory=8*1024, flops=DeviceFlops(fp32=1.0, fp16=2.0, int8=4.0)), 70 | ) 71 | self.topology.update_node( 72 | "node2", 73 | DeviceCapabilities(model="ModelB", chip="ChipB", memory=16*1024, flops=DeviceFlops(fp32=2.0, fp16=4.0, int8=8.0)), 74 | ) 75 | self.topology.update_node( 76 | "node3", 77 | DeviceCapabilities(model="ModelC", chip="ChipC", memory=32*1024, flops=DeviceFlops(fp32=4.0, fp16=8.0, int8=16.0)), 78 | ) 79 | self.topology.update_node( 80 | "node4", 81 | DeviceCapabilities(model="ModelD", chip="ChipD", memory=64*1024, flops=DeviceFlops(fp32=8.0, fp16=16.0, int8=32.0)), 82 | ) 83 | 84 | self.top_viz = TopologyViz() 85 | await asyncio.sleep(2) # Simulate running for a short time 86 | 87 | async def test_layout_generation(self): 88 | # self.top_viz._generate_layout() 89 | self.top_viz.refresh() 90 | import time 91 | 92 | time.sleep(2) 93 | self.top_viz.update_visualization( 94 | self.topology, 95 | [ 96 | Partition("node1", 0, 0.2), 97 | Partition("node4", 0.2, 0.4), 98 | Partition("node2", 0.4, 0.8), 99 | Partition("node3", 0.8, 0.9), 100 | ], 101 | "node1", 102 | { 103 | "node1": create_hf_repo_progress_event(), 104 | "node2": create_hf_repo_progress_event(), 105 | "node3": create_hf_repo_progress_event(), 106 | "node4": create_hf_repo_progress_event(), 107 | }, 108 | ) 109 | time.sleep(2) 110 | self.topology.active_node_id = "node3" 111 | self.top_viz.update_visualization( 112 | self.topology, 113 | [ 114 | Partition("node1", 0, 0.3), 115 | Partition("node5", 0.3, 0.5), 116 | Partition("node2", 0.5, 0.7), 117 | Partition("node4", 0.7, 0.9), 118 | ], 119 | "node5", 120 | { 121 | "node1": create_hf_repo_progress_event(), 122 | "node5": create_hf_repo_progress_event(), 123 | }, 124 | ) 125 | time.sleep(2) 126 | 127 | 128 | if __name__ == "__main__": 129 | unittest.main() 130 | -------------------------------------------------------------------------------- /extra/dashboard/requirements.txt: -------------------------------------------------------------------------------- 1 | plotly 2 | 
pandas 3 | requests 4 | aiohttp 5 | pygame -------------------------------------------------------------------------------- /extra/dashboard/sounds/gta5_wasted.mp3: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:fb3fb66dd02827fbff86ef1ce3bc6438371c823aed7d4c3803ed522f008e4947 3 | size 206399 4 | -------------------------------------------------------------------------------- /extra/dashboard/sounds/pokemon_evolve.mp3: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d99cc9bdab4a4639d50f439b424547000e7c79f195b5b121734ad4ead435911c 3 | size 633345 4 | -------------------------------------------------------------------------------- /extra/pipsize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib.metadata 3 | import importlib.util 4 | import json 5 | import sys 6 | 7 | 8 | def calc_container(path): 9 | """Calculate total size of a directory or file.""" 10 | if os.path.isfile(path): 11 | try: 12 | return os.path.getsize(path) 13 | except (OSError, FileNotFoundError): 14 | return 0 15 | 16 | total_size = 0 17 | for dirpath, dirnames, filenames in os.walk(path): 18 | for f in filenames: 19 | fp = os.path.join(dirpath, f) 20 | try: 21 | total_size += os.path.getsize(fp) 22 | except (OSError, FileNotFoundError): 23 | continue 24 | return total_size 25 | 26 | 27 | def get_package_location(package_name): 28 | """Get the actual location of a package's files.""" 29 | try: 30 | spec = importlib.util.find_spec(package_name) 31 | if spec is None: 32 | return None 33 | 34 | if spec.submodule_search_locations: 35 | # Return the first location for namespace packages 36 | return spec.submodule_search_locations[0] 37 | elif spec.origin: 38 | # For single-file modules, return the file path itself 39 | return spec.origin 40 | except ImportError: 41 | return None 42 | 43 | 44 | def get_package_sizes(min_size_mb=0.1): 45 | """Get sizes of installed packages above minimum size threshold.""" 46 | package_sizes = [] 47 | 48 | # Get all installed distributions 49 | for dist in importlib.metadata.distributions(): 50 | try: 51 | package_name = dist.metadata["Name"] 52 | location = get_package_location(package_name.replace("-", "_")) 53 | 54 | if location and os.path.exists(location): 55 | size = calc_container(location) 56 | size_mb = size / (1024 * 1024) 57 | 58 | if size_mb > min_size_mb: 59 | package_sizes.append((package_name, size)) 60 | except Exception as e: 61 | print( 62 | f"Error processing {dist.metadata.get('Name', 'Unknown package')}: {e}" 63 | ) 64 | 65 | return package_sizes 66 | 67 | 68 | def main(): 69 | # Get and sort package sizes 70 | package_sizes = get_package_sizes() 71 | package_sizes.sort(key=lambda x: x[1], reverse=True) 72 | 73 | # Convert sizes to MB and prepare data 74 | table_data = [(name, size/(1024*1024)) for name, size in package_sizes] 75 | total_size = sum(size for _, size in package_sizes)/(1024*1024) 76 | 77 | # Check if --json flag is present 78 | if "--json" in sys.argv: 79 | try: 80 | output_file = sys.argv[sys.argv.index("--json") + 1] 81 | json_data = { 82 | "packages": [{ 83 | "name": name, 84 | "size_mb": round(size, 2) 85 | } for name, size in table_data], 86 | "total_size_mb": round(total_size, 2) 87 | } 88 | 89 | with open(output_file, 'w') as f: 90 | json.dump(json_data, f, indent=2) 91 | print(f"JSON data written to 
{output_file}") 92 | return 93 | except IndexError: 94 | print("Error: Please provide a filename after --json") 95 | sys.exit(1) 96 | except Exception as e: 97 | print(f"Error writing JSON file: {e}") 98 | sys.exit(1) 99 | 100 | # Original table output code 101 | max_name_width = max(len(name) for name, _ in table_data) 102 | max_name_width = max(max_name_width, len("Package")) 103 | 104 | print(f"\n{'Package':<{max_name_width}} | Size (MB)") 105 | print("-" * max_name_width + "-+-" + "-" * 10) 106 | 107 | for name, size in table_data: 108 | print(f"{name:<{max_name_width}} | {size:>8.2f}") 109 | 110 | print(f"\nTotal size: {total_size:.2f} MB\n") 111 | 112 | if __name__ == "__main__": 113 | main() -------------------------------------------------------------------------------- /extra/start_openwebui.sh: -------------------------------------------------------------------------------- 1 | API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):52415}" 2 | echo "Using API_ENDPOINT=${API_ENDPOINT}" 3 | docker run -d -p 3000:8080 -e OPENAI_API_BASE_URL="${API_ENDPOINT}" -e OPENAI_API_KEY=your_secret_key -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:main 4 | -------------------------------------------------------------------------------- /format.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import subprocess 3 | import sys 4 | import os 5 | 6 | 7 | def run_yapf(target): 8 | if os.path.isfile(target): 9 | files = [target] 10 | else: 11 | files = [os.path.join(root, file) for root, _, files in os.walk(target) for file in files if file.endswith('.py')] 12 | 13 | for file in files: 14 | try: 15 | command = ["yapf", "-i", file] 16 | subprocess.run(command, check=True, capture_output=True, text=True) 17 | print(f"Formatted: {file}") 18 | except subprocess.CalledProcessError as e: 19 | print(f"Error formatting {file}: {e.stderr}") 20 | 21 | 22 | def main(): 23 | if len(sys.argv) < 2: 24 | print("Usage: python3 format.py e.g. python3 format.py ./exo") 25 | sys.exit(1) 26 | 27 | target = sys.argv[1] 28 | run_yapf(target) 29 | print("Formatting completed.") 30 | 31 | 32 | if __name__ == "__main__": 33 | main() 34 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | if command -v python3.12 &>/dev/null; then 4 | echo "Python 3.12 is installed, proceeding with python3.12..." 5 | python3.12 -m venv .venv 6 | else 7 | echo "The recommended version of Python to run exo with is Python 3.12, but $(python3 --version) is installed. Proceeding with $(python3 --version)" 8 | python3 -m venv .venv 9 | fi 10 | source .venv/bin/activate 11 | pip install -e . 
12 | -------------------------------------------------------------------------------- /scripts/build_exo.py: -------------------------------------------------------------------------------- 1 | import site 2 | import subprocess 3 | import sys 4 | import os 5 | import pkgutil 6 | 7 | def run(): 8 | site_packages = site.getsitepackages()[0] 9 | base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | baseimages_dir = os.path.join(base_dir, "exo", "apputil", "baseimages") 11 | 12 | command = [ 13 | f"{sys.executable}", "-m", "nuitka", "exo/main.py", 14 | "--company-name=exolabs", 15 | "--product-name=exo", 16 | "--output-dir=dist", 17 | "--follow-imports", 18 | "--standalone", 19 | "--output-filename=exo", 20 | "--python-flag=no_site", 21 | "--onefile", 22 | f"--include-data-dir={baseimages_dir}=exo/apputil/baseimages" 23 | ] 24 | 25 | if sys.platform == "darwin": 26 | command.extend([ 27 | "--macos-app-name=exo", 28 | "--macos-app-mode=gui", 29 | "--macos-app-version=0.0.1", 30 | "--macos-signed-app-name=net.exolabs.exo", 31 | "--include-distribution-meta=mlx", 32 | "--include-module=mlx._reprlib_fix", 33 | "--include-module=mlx._os_warning", 34 | "--include-distribution-meta=huggingface_hub", 35 | "--include-module=huggingface_hub.repocard", 36 | f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=mlx/lib/mlx.metallib", 37 | f"--include-data-files={site_packages}/mlx/lib/mlx.metallib=./mlx.metallib", 38 | "--include-distribution-meta=pygments", 39 | "--nofollow-import-to=tinygrad" 40 | ]) 41 | inference_modules = [ 42 | name for _, name, _ in pkgutil.iter_modules(['exo/inference/mlx/models']) 43 | ] 44 | for module in inference_modules: 45 | command.append(f"--include-module=exo.inference.mlx.models.{module}") 46 | elif sys.platform == "win32": 47 | command.extend([ 48 | "--windows-icon-from-ico=docs/exo-logo-win.ico", 49 | "--file-version=0.0.1", 50 | "--product-version=0.0.1" 51 | ]) 52 | elif sys.platform.startswith("linux"): 53 | command.extend([ 54 | "--include-distribution-metadata=pygments", 55 | "--linux-icon=docs/exo-rounded.png" 56 | ]) 57 | try: 58 | subprocess.run(command, check=True) 59 | print("Build completed!") 60 | except subprocess.CalledProcessError as e: 61 | print(f"An error occurred: {e}") 62 | 63 | if __name__ == "__main__": 64 | run() 65 | -------------------------------------------------------------------------------- /scripts/compile_grpc.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source ./install.sh 3 | pushd exo/networking/grpc 4 | python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto 5 | sed -i '' "s/import\ node_service_pb2/from . 
&/" node_service_pb2_grpc.py 6 | popd 7 | 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import platform 3 | import subprocess 4 | 5 | from setuptools import find_packages, setup 6 | 7 | # Base requirements for all platforms 8 | install_requires = [ 9 | "aiohttp==3.10.11", 10 | "aiohttp_cors==0.7.0", 11 | "aiofiles==24.1.0", 12 | "grpcio==1.70.0", 13 | "grpcio-tools==1.70.0", 14 | "Jinja2==3.1.4", 15 | "numpy==2.0.0", 16 | "nuitka==2.5.1", 17 | "nvidia-ml-py==12.560.30", 18 | "opencv-python==4.10.0.84", 19 | "pillow==10.4.0", 20 | "prometheus-client==0.20.0", 21 | "protobuf==5.28.1", 22 | "psutil==6.0.0", 23 | "pyamdgpuinfo==2.1.6;platform_system=='Linux'", 24 | "pydantic==2.9.2", 25 | "requests==2.32.3", 26 | "rich==13.7.1", 27 | "scapy==2.6.1", 28 | "tqdm==4.66.4", 29 | "transformers==4.46.3", 30 | "uuid==1.30", 31 | "uvloop==0.21.0", 32 | "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@ec120ce6b9ce8e4ff4b5692566a683ef240e8bc8", 33 | ] 34 | 35 | extras_require = { 36 | "formatting": ["yapf==0.40.2",], 37 | "apple_silicon": [ 38 | "mlx==0.22.0", 39 | "mlx-lm==0.21.1", 40 | ], 41 | "windows": ["pywin32==308",], 42 | "nvidia-gpu": ["nvidia-ml-py==12.560.30",], 43 | "amd-gpu": ["pyrsmi==0.2.0"], 44 | } 45 | 46 | # Check if running on macOS with Apple Silicon 47 | if sys.platform.startswith("darwin") and platform.machine() == "arm64": 48 | install_requires.extend(extras_require["apple_silicon"]) 49 | 50 | # Check if running Windows 51 | if sys.platform.startswith("win32"): 52 | install_requires.extend(extras_require["windows"]) 53 | 54 | 55 | def _add_gpu_requires(): 56 | global install_requires 57 | # Add Nvidia-GPU 58 | try: 59 | out = subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], shell=True, text=True, capture_output=True, check=False) 60 | if out.returncode == 0: 61 | install_requires.extend(extras_require["nvidia-gpu"]) 62 | except subprocess.CalledProcessError: 63 | pass 64 | 65 | # Add AMD-GPU 66 | # This will mostly work only on Linux, amd/rocm-smi is not yet supported on Windows 67 | try: 68 | out = subprocess.run(['amd-smi', 'list', '--csv'], shell=True, text=True, capture_output=True, check=False) 69 | if out.returncode == 0: 70 | install_requires.extend(extras_require["amd-gpu"]) 71 | except: 72 | out = subprocess.run(['rocm-smi', 'list', '--csv'], shell=True, text=True, capture_output=True, check=False) 73 | if out.returncode == 0: 74 | install_requires.extend(extras_require["amd-gpu"]) 75 | finally: 76 | pass 77 | 78 | 79 | _add_gpu_requires() 80 | 81 | setup( 82 | name="exo", 83 | version="0.0.1", 84 | packages=find_packages(), 85 | install_requires=install_requires, 86 | extras_require=extras_require, 87 | package_data={"exo": ["tinychat/**/*"]}, 88 | entry_points={"console_scripts": ["exo = exo.main:run"]}, 89 | ) 90 | -------------------------------------------------------------------------------- /test/reconnect.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | echo "Starting node 1" 4 | DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 52415 --chatgpt-api-response-timeout 900 > output1.log 2>&1 & 5 | PID1=$! 
6 | echo "Started node 1 PID: $PID1" 7 | echo "Starting node 2" 8 | DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout 900 > output2.log 2>&1 & 9 | PID2=$! 10 | echo "Started node 2 PID: $PID2" 11 | sleep 5 12 | kill $PID2 13 | sleep 5 14 | echo "Starting node 2 again..." 15 | DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --chatgpt-api-response-timeout 900 > output3.log 2>&1 & 16 | PID2=$! 17 | sleep 5 18 | echo "Killing nodes and ending test..." 19 | kill $PID1 20 | kill $PID2 21 | echo "Test complete." -------------------------------------------------------------------------------- /test/test_model_helpers.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from exo.models import get_supported_models, model_cards 3 | from exo.inference.inference_engine import inference_engine_classes 4 | from typing import NamedTuple 5 | 6 | class TestCase(NamedTuple): 7 | name: str 8 | engine_lists: list # Will contain short names, will be mapped to class names 9 | expected_models_contains: list 10 | min_count: int | None 11 | exact_count: int | None 12 | max_count: int | None 13 | 14 | # Helper function to map short names to class names 15 | def expand_engine_lists(engine_lists): 16 | def map_engine(engine): 17 | return inference_engine_classes.get(engine, engine) # Return original name if not found 18 | 19 | return [[map_engine(engine) for engine in sublist] 20 | for sublist in engine_lists] 21 | 22 | test_cases = [ 23 | TestCase( 24 | name="single_mlx_engine", 25 | engine_lists=[["mlx"]], 26 | expected_models_contains=["llama-3.2-1b", "llama-3.1-70b", "mistral-nemo"], 27 | min_count=10, 28 | exact_count=None, 29 | max_count=None 30 | ), 31 | TestCase( 32 | name="single_tinygrad_engine", 33 | engine_lists=[["tinygrad"]], 34 | expected_models_contains=["llama-3.2-1b", "llama-3.2-3b"], 35 | min_count=5, 36 | exact_count=None, 37 | max_count=15 38 | ), 39 | TestCase( 40 | name="multiple_engines_or", 41 | engine_lists=[["mlx", "tinygrad"], ["mlx"]], 42 | expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"], 43 | min_count=10, 44 | exact_count=None, 45 | max_count=None 46 | ), 47 | TestCase( 48 | name="multiple_engines_all", 49 | engine_lists=[["mlx", "tinygrad"], ["mlx", "tinygrad"]], 50 | expected_models_contains=["llama-3.2-1b", "llama-3.2-3b", "mistral-nemo"], 51 | min_count=10, 52 | exact_count=None, 53 | max_count=None 54 | ), 55 | TestCase( 56 | name="distinct_engine_lists", 57 | engine_lists=[["mlx"], ["tinygrad"]], 58 | expected_models_contains=["llama-3.2-1b"], 59 | min_count=5, 60 | exact_count=None, 61 | max_count=15 62 | ), 63 | TestCase( 64 | name="no_engines", 65 | engine_lists=[], 66 | expected_models_contains=None, 67 | min_count=None, 68 | exact_count=len(model_cards), 69 | max_count=None 70 | ), 71 | TestCase( 72 | name="nonexistent_engine", 73 | engine_lists=[["NonexistentEngine"]], 74 | expected_models_contains=[], 75 | min_count=None, 76 | exact_count=0, 77 | max_count=None 78 | ), 79 | TestCase( 80 | name="dummy_engine", 81 | engine_lists=[["dummy"]], 82 | expected_models_contains=["dummy"], 83 | min_count=None, 84 | exact_count=1, 85 | max_count=None 86 | ), 87 | ] 88 | 89 | class TestModelHelpers(unittest.TestCase): 90 | def test_get_supported_models(self): 91 | for case in test_cases: 92 | with 
self.subTest(f"{case.name}_short_names"): 93 | result = get_supported_models(case.engine_lists) 94 | self._verify_results(case, result) 95 | 96 | with self.subTest(f"{case.name}_class_names"): 97 | class_name_lists = expand_engine_lists(case.engine_lists) 98 | result = get_supported_models(class_name_lists) 99 | self._verify_results(case, result) 100 | 101 | def _verify_results(self, case, result): 102 | if case.expected_models_contains: 103 | for model in case.expected_models_contains: 104 | self.assertIn(model, result) 105 | 106 | if case.min_count: 107 | self.assertGreater(len(result), case.min_count) 108 | 109 | if case.exact_count is not None: 110 | self.assertEqual(len(result), case.exact_count) 111 | 112 | # Special case for distinct lists test 113 | if case.name == "distinct_engine_lists": 114 | self.assertLess(len(result), 15) 115 | self.assertNotIn("mistral-nemo", result) 116 | 117 | if case.max_count: 118 | self.assertLess(len(result), case.max_count) 119 | 120 | if __name__ == '__main__': 121 | unittest.main() 122 | -------------------------------------------------------------------------------- /test/test_tokenizers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from transformers import AutoTokenizer, AutoProcessor 4 | from exo.models import model_cards 5 | 6 | 7 | def test_tokenizer(name, tokenizer, verbose=False): 8 | print(f"--- {name} ({tokenizer.__class__.__name__}) ---") 9 | text = "Hello! How can I assist you today? Let me know if you need help with something or just want to chat." 10 | encoded = tokenizer.encode(text) 11 | decoded = tokenizer.decode(encoded) 12 | 13 | print(f"{encoded=}") 14 | print(f"{decoded=}") 15 | 16 | reconstructed = "" 17 | for token in encoded: 18 | if verbose: 19 | print(f"{token=}") 20 | print(f"{tokenizer.decode([token])=}") 21 | reconstructed += tokenizer.decode([token]) 22 | print(f"{reconstructed=}") 23 | 24 | strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id])) 25 | assert text == strip_tokens(decoded) == strip_tokens(reconstructed) 26 | 27 | ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit", "stabilityai/stable-diffusion-2-1-base"] 28 | ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")") 29 | models = [] 30 | for model_id in model_cards: 31 | for engine_type, repo_id in model_cards[model_id].get("repo", {}).items(): 32 | if not ignore_pattern.match(repo_id): 33 | models.append(repo_id) 34 | models = list(set(models)) 35 | 36 | verbose = os.environ.get("VERBOSE", "0").lower() == "1" 37 | for m in models: 38 | # TODO: figure out why use_fast=False is giving inconsistent behaviour (no spaces decoding invididual tokens) for Mistral-Large-Instruct-2407-4bit 39 | # test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=False), verbose) 40 | if m not in ["mlx-community/DeepSeek-R1-4bit", "mlx-community/DeepSeek-R1-3bit", "mlx-community/DeepSeek-V3-4bit", "mlx-community/DeepSeek-V3-3bit"]: 41 | test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=True, trust_remote_code=True), verbose) 42 | test_tokenizer(m, AutoTokenizer.from_pretrained(m, 
trust_remote_code=True), verbose) 43 | --------------------------------------------------------------------------------