├── .bumpversion.cfg ├── .gitattributes ├── .github └── workflows │ ├── docs.yml │ ├── main.yml │ ├── release.yml │ └── tests.yml ├── .gitignore ├── .gitmodules ├── .vscode └── settings.json ├── CHANGELOG.md ├── README.md ├── docs ├── accounts.md ├── addresses.md ├── clearing_house.md ├── clearing_house_user.md ├── css │ ├── custom.css │ └── mkdocstrings.css ├── img │ └── drift.png └── index.md ├── examples ├── authority_user-account_lookup.py ├── check_pyth_lazer.py ├── deposit_and_withdraw.py ├── driftpy-enhanced-usermap.py ├── driftpy-marketmap-details.py ├── dump_pickle.py ├── example_grpc_log_provider.py ├── fetch_all_markets.py ├── floating_maker.py ├── get_all_users_pnl.py ├── get_borrows.py ├── get_deposits.py ├── get_funding.py ├── get_vault_metrics.py ├── grpc_auction.py ├── grpc_client.py ├── high_leverage_users.py ├── if_stake.py ├── jit_maker_grpc.py ├── limit_order_grid.py ├── minimal_maker.py ├── oracle_maker.py ├── orders_subcribe.py ├── place_and_take.py ├── pmm_users.py ├── protected_order.py ├── readme.md ├── send_indicative_quotes.py ├── settle_pnl.py ├── signed_msg_subscribe.py ├── spot_market_trade.py ├── start_lp.py ├── swift_maker.py ├── swift_take_place_and_make.py ├── swift_taker.py ├── swift_taker_delegate.py ├── view.py └── ws_simple.py ├── mkdocs.yml ├── poetry.lock ├── poetry.toml ├── pyproject.toml ├── requirements.txt ├── scripts ├── bump.py ├── ci.sh ├── decode.sh ├── dlob.sh ├── generate_constants.py └── math.sh ├── setup.sh ├── src └── driftpy │ ├── __init__.py │ ├── account_subscription_config.py │ ├── accounts │ ├── __init__.py │ ├── bulk_account_loader.py │ ├── cache │ │ ├── __init__.py │ │ ├── drift_client.py │ │ └── user.py │ ├── demo │ │ ├── __init__.py │ │ ├── drift_client.py │ │ └── user.py │ ├── get_accounts.py │ ├── grpc │ │ ├── account_subscriber.py │ │ ├── drift_client.py │ │ ├── geyser_codegen │ │ │ ├── README.md │ │ │ ├── geyser_pb2.py │ │ │ ├── geyser_pb2.pyi │ │ │ ├── geyser_pb2_grpc.py │ │ │ ├── 
solana_storage_pb2.py │ │ │ ├── solana_storage_pb2.pyi │ │ │ ├── solana_storage_pb2_grpc.py │ │ │ └── subscribe_geyser.py │ │ ├── program_account_subscriber.py │ │ ├── user.py │ │ └── user_stats.py │ ├── oracle.py │ ├── polling │ │ ├── __init__.py │ │ ├── drift_client.py │ │ └── user.py │ ├── types.py │ └── ws │ │ ├── __init__.py │ │ ├── account_subscriber.py │ │ ├── drift_client.py │ │ ├── program_account_subscriber.py │ │ ├── user.py │ │ └── user_stats.py │ ├── address_lookup_table.py │ ├── addresses.py │ ├── admin.py │ ├── auction_subscriber │ ├── auction_subscriber.py │ ├── grpc_auction_subscriber.py │ └── types.py │ ├── constants │ ├── __init__.py │ ├── config.py │ ├── numeric_constants.py │ ├── perp_markets.py │ └── spot_markets.py │ ├── decode │ ├── pull_oracle.py │ ├── signed_msg_order.py │ ├── user.py │ ├── user_stat.py │ └── utils.py │ ├── dlob │ ├── client_types.py │ ├── dlob.py │ ├── dlob_helpers.py │ ├── dlob_node.py │ ├── dlob_subscriber.py │ ├── node_list.py │ └── orderbook_levels.py │ ├── drift_client.py │ ├── drift_user.py │ ├── drift_user_stats.py │ ├── events │ ├── __init__.py │ ├── event_list.py │ ├── event_subscriber.py │ ├── fetch_logs.py │ ├── grpc_log_provider.py │ ├── parse.py │ ├── polling_log_provider.py │ ├── sort.py │ ├── tx_event_cache.py │ ├── types.py │ └── websocket_log_provider.py │ ├── idl │ ├── __init__.py │ ├── drift.json │ ├── drift_vaults.json │ ├── pyth.json │ ├── sequence_enforcer.json │ ├── switchboard.json │ ├── switchboard_on_demand.json │ └── token_faucet.json │ ├── indicative_quotes │ ├── __init__.py │ └── indicative_quotes_sender.py │ ├── keypair.py │ ├── market_map │ ├── grpc_market_map.py │ ├── grpc_sub.py │ ├── market_map.py │ ├── market_map_config.py │ └── websocket_sub.py │ ├── math │ ├── amm.py │ ├── auction.py │ ├── conversion.py │ ├── exchange_status.py │ ├── fuel.py │ ├── funding.py │ ├── margin.py │ ├── market.py │ ├── oracles.py │ ├── orders.py │ ├── perp_position.py │ ├── repeg.py │ ├── spot_balance.py │ 
├── spot_market.py │ ├── spot_position.py │ ├── user_status.py │ └── utils.py │ ├── memcmp.py │ ├── name.py │ ├── oracles │ ├── oracle_id.py │ └── strict_oracle_price.py │ ├── pickle │ └── vat.py │ ├── priority_fees │ └── priority_fee_subscriber.py │ ├── py.typed │ ├── setup │ └── helpers.py │ ├── slot │ └── slot_subscriber.py │ ├── swift │ ├── create_verify_ix.py │ ├── order_subscriber.py │ └── util.py │ ├── tx │ ├── __init__.py │ ├── fast_tx_sender.py │ ├── jito_subscriber.py │ ├── jito_tx_sender.py │ ├── standard_tx_sender.py │ └── types.py │ ├── types.py │ ├── user_map │ ├── polling_sub.py │ ├── types.py │ ├── user_map.py │ ├── user_map_config.py │ ├── userstats_map.py │ └── websocket_sub.py │ └── vaults │ ├── __init__.py │ ├── helpers.py │ └── vault_client.py ├── tests ├── __init__.py ├── ci │ ├── __init__.py │ ├── devnet.py │ └── mainnet.py ├── decode │ ├── __init__.py │ ├── decode.py │ ├── decode_stat.py │ ├── decode_strings.py │ ├── dlob_test_helpers.py │ └── stat_decode_strings.py ├── dlob │ ├── __init__.py │ └── dlob.py ├── dlob_test_constants.py ├── integration │ ├── events_parser.py │ ├── liq.py │ ├── oracle.py │ ├── prelaunch.py │ ├── swb_on_demand.py │ └── test_oracle_diff_sources.py └── math │ ├── __init__.py │ ├── amm.py │ ├── auction.py │ ├── funding.py │ ├── helpers.py │ ├── insurance.py │ ├── spot.py │ ├── spreads.py │ └── user.py └── update_idl.sh /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.8.56 3 | commit = True 4 | tag = True 5 | tag_name = {new_version} 6 | 7 | [bumpversion:file:pyproject.toml] 8 | search = version = "{current_version}" 9 | replace = version = "{new_version}" 10 | 11 | [bumpversion:file:src/driftpy/__init__.py] 12 | search = __version__ = "{current_version}" 13 | replace = __version__ = "{new_version}" 14 | -------------------------------------------------------------------------------- /.gitattributes: 
-------------------------------------------------------------------------------- 1 | drift-core/** linguist-vendored 2 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Docs 2 | on: 3 | release: 4 | types: [published] 5 | jobs: 6 | docs: 7 | name: Deploy docs 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Checkout 11 | uses: actions/checkout@v2.5.0 12 | 13 | - name: Set up Python 14 | uses: actions/setup-python@v4.3.0 15 | with: 16 | python-version: '3.10' 17 | 18 | - name: Install dependencies 19 | run: pip install mkdocs mkdocstrings-python mkdocs-material driftpy 20 | 21 | - name: Deploy docs 22 | run: mkdocs gh-deploy --force 23 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | push: 4 | branches: 5 | - master 6 | pull_request: 7 | branches: [master] 8 | 9 | defaults: 10 | run: 11 | shell: bash 12 | working-directory: . 13 | 14 | jobs: 15 | tests: 16 | runs-on: ubicloud 17 | steps: 18 | - uses: actions/checkout@v3 19 | - name: Set up Python 3.10 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: "3.10.10" 23 | - name: Install and configure Poetry 24 | uses: snok/install-poetry@v1.3.3 25 | with: 26 | version: 1.4.2 27 | virtualenvs-create: true 28 | virtualenvs-in-project: true 29 | installer-parallel: true 30 | - name: Install dependencies 31 | run: poetry install 32 | - name: Install pytest 33 | run: poetry run pip install pytest 34 | - name: Install ruff 35 | run: poetry run pip install ruff 36 | - name: Run tests 37 | env: 38 | MAINNET_RPC_ENDPOINT: ${{ secrets.MAINNET_RPC_ENDPOINT }} 39 | DEVNET_RPC_ENDPOINT: ${{ secrets.DEVNET_RPC_ENDPOINT }} 40 | # run: poetry run ruff format --check . 
&& poetry run bash scripts/ci.sh 41 | run: poetry run bash scripts/ci.sh 42 | 43 | bump-version: 44 | runs-on: ubicloud 45 | needs: [tests] 46 | if: github.event_name == 'push' && github.ref == 'refs/heads/master' 47 | steps: 48 | - uses: actions/checkout@v3 49 | with: 50 | fetch-depth: 0 51 | - name: Set up Python 3.10 52 | uses: actions/setup-python@v4 53 | with: 54 | python-version: "3.10.10" 55 | - name: Run version bump script 56 | run: python scripts/bump.py 57 | - name: Commit changes 58 | run: | 59 | git config --local user.email "github-actions[bot]@users.noreply.github.com" 60 | git config --local user.name "github-actions[bot]" 61 | git add pyproject.toml src/driftpy/__init__.py .bumpversion.cfg 62 | git commit -m "Bump version [skip ci]" 63 | - name: Push changes 64 | uses: ad-m/github-push-action@master 65 | with: 66 | github_token: ${{ secrets.GITHUB_TOKEN }} 67 | branch: ${{ github.ref }} 68 | 69 | release: 70 | runs-on: ubicloud 71 | needs: [bump-version] 72 | if: github.event_name == 'push' && github.ref == 'refs/heads/master' 73 | steps: 74 | - name: Checkout 75 | uses: actions/checkout@v3 76 | with: 77 | fetch-depth: 0 78 | - name: Pull Latest Changes 79 | run: | 80 | git config --local user.email "github-actions[bot]@users.noreply.github.com" 81 | git config --local user.name "github-actions[bot]" 82 | git pull 83 | - name: Set up Python 84 | uses: actions/setup-python@v4 85 | with: 86 | python-version: "3.10.10" 87 | - name: Install and configure Poetry 88 | uses: snok/install-poetry@v1.3.3 89 | with: 90 | version: 1.4.2 91 | virtualenvs-create: true 92 | virtualenvs-in-project: true 93 | installer-parallel: true 94 | - name: Build package 95 | run: poetry build 96 | - name: Publish to PyPI 97 | run: poetry publish --username=__token__ --password=${{ secrets.PYPI_TOKEN }} 98 | - name: Get version 99 | id: get_version 100 | run: echo "VERSION=$(poetry version -s)" >> $GITHUB_OUTPUT 101 | - name: Create GitHub Release 102 | env: 103 | 
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 104 | run: | 105 | gh release create v${{ steps.get_version.outputs.VERSION }} \ 106 | --title "Release v${{ steps.get_version.outputs.VERSION }}" \ 107 | --generate-notes \ 108 | dist/*.whl dist/*.tar.gz 109 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | on: 3 | release: 4 | types: [published] 5 | jobs: 6 | release: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: Checkout 10 | uses: actions/checkout@v2.5.0 11 | 12 | - name: Set up Python 13 | uses: actions/setup-python@v4.3.0 14 | with: 15 | python-version: '3.10.10' 16 | #---------------------------------------------- 17 | # ----- install & configure poetry ----- 18 | #---------------------------------------------- 19 | - name: Install and configure Poetry 20 | uses: snok/install-poetry@v1.3.3 21 | with: 22 | version: 1.4.2 23 | virtualenvs-create: true 24 | virtualenvs-in-project: true 25 | installer-parallel: true 26 | - run: poetry build 27 | - run: poetry publish --username=__token__ --password=${{ secrets.PYPI_TOKEN }} 28 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | on: 3 | push: 4 | branches: [master] 5 | pull_request: 6 | branches: [master] 7 | 8 | env: 9 | solana_verion: 1.14.7 10 | anchor_version: 0.26.0 11 | 12 | jobs: 13 | tests: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Checkout repo. 
17 | uses: actions/checkout@v2 18 | with: 19 | submodules: 'recursive' 20 | 21 | - name: Cache Solana Tool Suite 22 | uses: actions/cache@v2 23 | id: cache-solana 24 | with: 25 | path: | 26 | ~/.cache/solana/ 27 | ~/.local/share/solana/ 28 | key: solana-${{ runner.os }}-v0000-${{ env.solana_verion }} 29 | 30 | - name: Install Rust toolchain 31 | uses: actions-rs/toolchain@v1 32 | with: 33 | profile: minimal 34 | toolchain: nightly 35 | override: true 36 | 37 | - name: Install Solana 38 | if: steps.cache-solana.outputs.cache-hit != 'true' 39 | run: sh -c "$(curl -sSfL https://release.solana.com/v${{ env.solana_verion }}/install)" 40 | 41 | - name: Add Solana to path 42 | run: echo "/home/runner/.local/share/solana/install/active_release/bin" >> $GITHUB_PATH 43 | 44 | - uses: actions/setup-node@v2 45 | with: 46 | node-version: '17' 47 | 48 | - name: install Anchor CLI 49 | run: npm install -g @project-serum/anchor-cli 50 | 51 | - name: Generate local keypair 52 | run: yes | solana-keygen new 53 | 54 | - name: Set up Python 55 | uses: actions/setup-python@v1 56 | with: 57 | python-version: 3.9 58 | 59 | #---------------------------------------------- 60 | # ----- install & configure poetry ----- 61 | #---------------------------------------------- 62 | - name: Install and configure Poetry 63 | uses: snok/install-poetry@v1 64 | with: 65 | version: 1.1.10 66 | virtualenvs-create: true 67 | virtualenvs-in-project: true 68 | installer-parallel: true 69 | #---------------------------------------------- 70 | # load cached venv if cache exists 71 | #---------------------------------------------- 72 | - name: Load cached venv 73 | id: cached-poetry-dependencies 74 | uses: actions/cache@v2 75 | with: 76 | path: .venv 77 | key: venv-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }} 78 | #---------------------------------------------- 79 | # install dependencies if cache does not exist (todo) 80 | #---------------------------------------------- 81 | - name: Install 
dependencies 82 | # if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 83 | run: poetry install --no-interaction --no-root 84 | #---------------------------------------------- 85 | # install your root project 86 | #---------------------------------------------- 87 | - name: Install library 88 | run: poetry install --no-interaction 89 | #---------------------------------------------- 90 | # install nox-poetry 91 | #---------------------------------------------- 92 | - name: Install nox-poetry 93 | run: pip install nox-poetry 94 | #---------------------------------------------- 95 | # run linters 96 | #---------------------------------------------- 97 | - name: Run linters 98 | run: make lint 99 | #---------------------------------------------- 100 | # run test suite 101 | #---------------------------------------------- 102 | - name: Run tests 103 | run: make test 104 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .tmp 2 | .DS_Store 3 | node.txt 4 | keypairs/ 5 | test-ledger/ 6 | 7 | node_modules/ 8 | yarn.lock 9 | package-lock.json 10 | package.json 11 | .prettierrc 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | docs/.DS_Store 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | share/python-wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | *.py,cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | cover/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | db.sqlite3-journal 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | .pybuilder/ 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # IPython 95 | profile_default/ 96 | ipython_config.py 97 | 98 | # pyenv 99 | # For a library or package, you might want to ignore these files since the code is 100 | # intended to run in multiple environments; otherwise, check them in: 101 | # .python-version 102 | 103 | # pipenv 104 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 105 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 106 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 107 | # install all needed dependencies. 108 | #Pipfile.lock 109 | 110 | # poetry 111 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 112 | # This is especially recommended for binary packages to ensure reproducibility, and is more 113 | # commonly ignored for libraries. 114 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 115 | #poetry.lock 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | .idea/ 166 | 167 | scratch 168 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "protocol-v2"] 2 | path = protocol-v2 3 | url = https://github.com/drift-labs/protocol-v2.git 4 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.formatting.provider": "black", 3 | "python.analysis.extraPaths": [ 4 | "${workspaceFolder}/src/", 5 | "./venv/bin/python" 6 | ] 7 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DriftPy 2 | 3 |
4 | 5 |
6 | 7 | DriftPy is the Python client for the [Drift](https://www.drift.trade/) protocol. 8 | It allows you to trade and fetch data from Drift using Python. 9 | 10 | **[Read the full SDK documentation here!](https://drift-labs.github.io/v2-teacher/)** 11 | 12 | ## Installation 13 | 14 | ``` 15 | pip install driftpy 16 | ``` 17 | 18 | Note: requires Python >= 3.10. 19 | 20 | 21 | ## SDK Examples 22 | 23 | - The `examples/` folder includes more examples of how to use the SDK, including how to provide liquidity (become an LP), stake in the insurance fund, etc. 24 | 25 | 26 | ## Note on using QuickNode 27 | 28 | If you are using the QuickNode free plan, you *must* use `AccountSubscriptionConfig("demo")`, and you can only subscribe to 1 perp market and 1 spot market at a time. 29 | 30 | Non-QuickNode free RPCs (including the public mainnet-beta URL) can use `cached` as well. 31 | 32 | Example setup for `AccountSubscriptionConfig("demo")`: 33 | 34 | ```python 35 | # This example will listen to perp markets 0 & 1 and spot market 0 36 | # If you are listening to any perp markets, you must listen to spot market 0 or the SDK will break 37 | 38 | perp_markets = [0, 1] 39 | spot_market_oracle_infos, perp_market_oracle_infos, spot_market_indexes = get_markets_and_oracles(perp_markets = perp_markets) 40 | 41 | oracle_infos = spot_market_oracle_infos + perp_market_oracle_infos 42 | 43 | drift_client = DriftClient( 44 | connection, 45 | wallet, 46 | "mainnet", 47 | perp_market_indexes = perp_markets, 48 | spot_market_indexes = spot_market_indexes, 49 | oracle_infos = oracle_infos, 50 | account_subscription = AccountSubscriptionConfig("demo"), 51 | ) 52 | await drift_client.subscribe() 53 | ``` 54 | If you intend to use `AccountSubscriptionConfig("demo")`, you *must* call `get_markets_and_oracles` to get the information you need. 55 | 56 | `get_markets_and_oracles` will return all the necessary `OracleInfo`s and `market_indexes` in order to use the SDK. 57 |
58 | # Development 59 | 60 | ## Setting Up Dev Env 61 | 62 | `bash setup.sh` 63 | 64 | 65 | Ensure you are using the correct Python version (using pyenv is recommended): 66 | ```bash 67 | pyenv install 3.10.11 68 | pyenv global 3.10.11 69 | poetry env use $(pyenv which python) 70 | ``` 71 | 72 | Install dependencies: 73 | ```bash 74 | poetry install 75 | ``` 76 | 77 | To run tests, first ensure you have set up the RPC URL, then run `pytest`: 78 | ```bash 79 | export MAINNET_RPC_ENDPOINT="" 80 | export DEVNET_RPC_ENDPOINT="https://api.devnet.solana.com" # or your own RPC 81 | 82 | poetry run pytest -v -s -x tests/ci/*.py 83 | poetry run pytest -v -s tests/math/*.py 84 | ``` 85 | -------------------------------------------------------------------------------- /docs/accounts.md: -------------------------------------------------------------------------------- 1 | # Accounts 2 | 3 | These functions are used to retrieve specific on-chain accounts (State, PerpMarket, SpotMarket, etc.) 4 | 5 | ## Example 6 | 7 | ```python 8 | drift_client = DriftClient.from_config(config, provider) 9 | 10 | # get sol market info 11 | sol_market_index = 0 12 | sol_market = await get_perp_market_account(drift_client.program, sol_market_index) 13 | print( 14 | sol_market.amm.sqrt_k, 15 | sol_market.amm.base_asset_amount_long, 16 | sol_market.amm.base_asset_amount_short, 17 | ) 18 | 19 | # get usdc spot market info 20 | usdc_spot_market_index = 0 21 | usdc_market = await get_spot_market_account(drift_client.program, usdc_spot_market_index) 22 | print( 23 | usdc_market.market_index, 24 | usdc_market.deposit_balance, 25 | usdc_market.borrow_balance, 26 | ) 27 | ``` 28 | 29 | :::driftpy.accounts 30 | -------------------------------------------------------------------------------- /docs/addresses.md: -------------------------------------------------------------------------------- 1 | # Addresses 2 | 3 | These functions are used to derive the on-chain addresses of accounts (e.g. the public key of the SOL market) 4 | 5 | :::driftpy.addresses
-------------------------------------------------------------------------------- /docs/clearing_house.md: -------------------------------------------------------------------------------- 1 | # Drift Client 2 | 3 | This object is used to interact with the protocol (deposit, withdraw, trade, lp, etc.) 4 | 5 | ## Example 6 | 7 | ```python 8 | drift_client = DriftClient.from_config(config,provider) 9 | # open a 10 SOL long position 10 | sig = await drift_client.open_position( 11 | PositionDirection.LONG(), # long 12 | int(10 * BASE_PRECISION), # 10 in base precision 13 | 0, # sol market index 14 | ) 15 | 16 | # mint 100 LP shares on the SOL market 17 | await drift_client.add_liquidity( 18 | int(100 * AMM_RESERVE_PRECISION), 19 | 0, 20 | ) 21 | ``` 22 | 23 | ## Configuration 24 | 25 | Use the `JUPITER_URL` environment variable to set the endpoint URL for the Jupiter V6 Swap API. This allows you to switch between self-hosted, paid-hosted, or other public API endpoints such as [jupiterapi.com](https://www.jupiterapi.com/) for higher rate limits and reduced latency. For more details, see the official [self-hosted](https://station.jup.ag/docs/apis/self-hosted) and [paid-hosted](https://station.jup.ag/docs/apis/self-hosted#paid-hosted-apis) documentation. 26 | 27 | :::driftpy.drift_client 28 | -------------------------------------------------------------------------------- /docs/clearing_house_user.md: -------------------------------------------------------------------------------- 1 | # User 2 | 3 | This object is used to fetch data from the protocol and view user metrics (leverage, free collateral, etc.) 
4 | 5 | ## Example 6 | 7 | ```python 8 | drift_client = DriftClient.from_config(config, provider) 9 | drift_user = User(drift_client) 10 | 11 | # inspect user's leverage 12 | leverage = await drift_user.get_leverage() 13 | print('current leverage:', leverage / 10_000) 14 | 15 | # you can also inspect other accounts information using the (authority=) flag 16 | bigz_acc = User(drift_client, authority=PublicKey('bigZ')) 17 | leverage = await bigz_acc.get_leverage() 18 | print('bigZs leverage:', leverage / 10_000) 19 | 20 | # user calls can be expensive on the rpc so we can cache them 21 | drift_user = User(drift_client, use_cache=True) 22 | await drift_user.set_cache() 23 | 24 | # works without any rpc calls (uses the cached data) 25 | upnl = await drift_user.get_unrealized_pnl(with_funding=True) 26 | print('upnl:', upnl) 27 | ``` 28 | 29 | :::driftpy.drift_user 30 | -------------------------------------------------------------------------------- /docs/css/custom.css: -------------------------------------------------------------------------------- 1 | a.external-link::after { 2 | /* \00A0 is a non-breaking space 3 | to make the mark be on the same line as the link 4 | */ 5 | content: "\00A0[↪]"; 6 | } 7 | 8 | a.internal-link::after { 9 | /* \00A0 is a non-breaking space 10 | to make the mark be on the same line as the link 11 | */ 12 | content: "\00A0↪"; 13 | } 14 | -------------------------------------------------------------------------------- /docs/css/mkdocstrings.css: -------------------------------------------------------------------------------- 1 | /* Indentation. 
*/ 2 | div.doc-contents:not(.first) { 3 | padding-left: 25px; 4 | border-left: 4px solid rgba(230, 230, 230); 5 | margin-bottom: 80px; 6 | } 7 | -------------------------------------------------------------------------------- /docs/img/drift.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/docs/img/drift.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Drift-v2 Python SDK 2 | 3 |
4 | 5 |
6 | 7 | DriftPy is the Python SDK for [Drift-v2](https://www.drift.trade/) on Solana. 8 | It allows you to trade and fetch data from Drift using Python. 9 | 10 | ## Installation 11 | 12 | ``` 13 | pip install driftpy 14 | ``` 15 | 16 | Note: requires Python >= 3.10. 17 | 18 | ## Key Components 19 | 20 | - `DriftClient` / `drift_client.py`: Used to interact with the protocol (deposit, withdraw, trade, lp, etc.) 21 | - `DriftUser` / `drift_user.py`: Used to fetch data from the protocol and view user metrics (leverage, free collateral, etc.) 22 | - `accounts.py`: Used to retrieve specific on-chain accounts (State, PerpMarket, SpotMarket, etc.) 23 | - `addresses.py`: Used to derive on-chain addresses of the accounts (publickey of the sol-market) 24 | 25 | ## Example 26 | 27 | ```python 28 | 29 | from solana.keypair import Keypair 30 | from driftpy.drift_client import DriftClient 31 | from driftpy.drift_user import DriftUser 32 | from driftpy.constants.numeric_constants import BASE_PRECISION, AMM_RESERVE_PRECISION 33 | 34 | from anchorpy import Provider, Wallet 35 | from solana.rpc.async_api import AsyncClient 36 | 37 | # load keypair from file 38 | KEYPATH = '../your-keypair-secret.json' 39 | with open(KEYPATH, 'r') as f: 40 | secret = json.load(f) 41 | kp = Keypair.from_secret_key(bytes(secret)) 42 | 43 | # create clearing house for mainnet 44 | ENV = 'mainnet' 45 | config = configs[ENV] 46 | wallet = Wallet(kp) 47 | connection = AsyncClient(config.default_http) 48 | provider = Provider(connection, wallet) 49 | 50 | drift_client = DriftClient.from_config(config, provider) 51 | drift_user = DriftUser(drift_client) 52 | 53 | # open a 10 SOL long position 54 | sig = await drift_client.open_position( 55 | PositionDirection.LONG(), # long 56 | int(10 * BASE_PRECISION), # 10 in base precision 57 | 0, # sol market index 58 | ) 59 | 60 | # mint 100 LP shares on the SOL market 61 | await drift_client.add_liquidity( 62 | int(100 * AMM_RESERVE_PRECISION), 63 | 0, 64 | ) 65 | 66 
| # inspect user's leverage 67 | leverage = await drift_user.get_leverage() 68 | print('current leverage:', leverage / 10_000) 69 | 70 | # you can also inspect other accounts information using the (authority=) flag 71 | bigz_acc = DriftUser(drift_client, user_public_key=PublicKey('bigZ')) 72 | leverage = await bigz_acc.get_leverage() 73 | print('bigZs leverage:', leverage / 10_000) 74 | 75 | # clearing house user calls can be expensive on the rpc so we can cache them 76 | drift_user = DriftUser(drift_client, use_cache=True) 77 | await drift_user.set_cache() 78 | 79 | # works without any rpc calls (uses the cached data) 80 | upnl = await drift_user.get_unrealized_pnl(with_funding=True) 81 | print('upnl:', upnl) 82 | ``` 83 | -------------------------------------------------------------------------------- /examples/check_pyth_lazer.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | 4 | from anchorpy.provider import Wallet 5 | from dotenv import load_dotenv 6 | from solana.rpc.async_api import AsyncClient 7 | 8 | from driftpy.drift_client import DriftClient 9 | from driftpy.types import is_variant 10 | 11 | 12 | async def main(): 13 | load_dotenv() 14 | url = os.getenv("DEVNET_RPC_ENDPOINT", "https://api.devnet.solana.com") 15 | connection = AsyncClient(url) 16 | print("RPC URL:", url) 17 | 18 | print("Checking devnet constants") 19 | drift_client = DriftClient( 20 | connection, 21 | Wallet.dummy(), 22 | env="devnet", 23 | ) 24 | 25 | print("Subscribing to Drift Client") 26 | await drift_client.subscribe() 27 | received_perp_markets = sorted( 28 | drift_client.get_perp_market_accounts(), key=lambda market: market.market_index 29 | ) 30 | for market in received_perp_markets: 31 | oracle_data = drift_client.get_user().get_oracle_data_for_perp_market( 32 | market.market_index 33 | ) 34 | if oracle_data and ( 35 | is_variant(market.amm.oracle_source, "PythLazer") 36 | or 
is_variant(market.amm.oracle_source, "PythLazer1K") 37 | or is_variant(market.amm.oracle_source, "PythLazer1M") 38 | ): 39 | print( 40 | market.market_index, 41 | market.amm.oracle, 42 | bytes(market.name).decode("utf-8").strip(), 43 | market.amm.oracle_source, 44 | oracle_data.price / 10**6, 45 | ) 46 | 47 | print("Subscribed to Drift Client") 48 | await drift_client.unsubscribe() 49 | await connection.close() 50 | 51 | 52 | if __name__ == "__main__": 53 | asyncio.run(main()) 54 | -------------------------------------------------------------------------------- /examples/deposit_and_withdraw.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | 4 | from anchorpy.provider import Provider, Wallet 5 | from dotenv import load_dotenv 6 | from solana.rpc.async_api import AsyncClient 7 | 8 | from driftpy.drift_client import DriftClient 9 | from driftpy.keypair import load_keypair 10 | from driftpy.types import TxParams 11 | 12 | LAMPORTS_PER_SOL = 10**9 13 | 14 | 15 | load_dotenv() 16 | 17 | 18 | async def make_spot_trade(): 19 | rpc = os.environ.get("RPC_TRITON") 20 | secret = os.environ.get("PRIVATE_KEY") 21 | kp = load_keypair(secret) 22 | wallet = Wallet(kp) 23 | print(f"Using wallet: {wallet.public_key}") 24 | 25 | connection = AsyncClient(rpc) 26 | provider = Provider(connection, wallet) 27 | drift_client = DriftClient( 28 | provider.connection, 29 | provider.wallet, 30 | env="mainnet", 31 | tx_params=TxParams(compute_units_price=85_000, compute_units=1_400_000), 32 | ) 33 | await drift_client.subscribe() 34 | 35 | print("Drift client subscribed") 36 | 37 | amount = int(0.4 * LAMPORTS_PER_SOL) 38 | 39 | print("Depositing 0.4") 40 | await drift_client.deposit( 41 | amount=amount, 42 | spot_market_index=1, 43 | user_token_account=drift_client.wallet.public_key, 44 | ) 45 | print("Deposited 0.4") 46 | 47 | print("Withdrawing 0.4") 48 | await drift_client.withdraw( 49 | amount=amount, 50 | 
market_index=1, 51 | user_token_account=drift_client.wallet.public_key, 52 | ) 53 | print("Withdrew 0.4") 54 | 55 | 56 | if __name__ == "__main__": 57 | asyncio.run(make_spot_trade()) 58 | -------------------------------------------------------------------------------- /examples/example_grpc_log_provider.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | 4 | from dotenv import load_dotenv 5 | from solana.rpc.commitment import Commitment 6 | 7 | from driftpy.events.grpc_log_provider import GrpcLogProvider 8 | from driftpy.events.types import GrpcLogProviderConfig 9 | 10 | load_dotenv() 11 | 12 | 13 | async def simple_log_callback(signature: str, slot: int, logs: list[str]): 14 | print(f"Received logs for tx: {signature} in slot: {slot}") 15 | print("---") 16 | 17 | 18 | async def main(): 19 | GRPC_ENDPOINT = os.getenv("GRPC_ENDPOINT") 20 | AUTH_TOKEN = os.getenv("GRPC_AUTH_TOKEN") 21 | if not GRPC_ENDPOINT or not AUTH_TOKEN: 22 | raise ValueError( 23 | "GRPC_ENDPOINT and GRPC_AUTH_TOKEN must be set in the environment" 24 | ) 25 | 26 | USER_ACCOUNT_TO_FILTER = "BrRpSaQ6hFDw8darPCyP9Sw7sjydMFQqB4ECAotXSEci" 27 | print(f"Attempting to connect to gRPC endpoint: {GRPC_ENDPOINT}") 28 | 29 | grpc_provider_config = GrpcLogProviderConfig( 30 | endpoint=GRPC_ENDPOINT, 31 | token=AUTH_TOKEN, 32 | ) 33 | 34 | commitment = Commitment("confirmed") 35 | 36 | log_provider = GrpcLogProvider( 37 | grpc_config=grpc_provider_config, 38 | commitment=commitment, 39 | user_account_to_filter=USER_ACCOUNT_TO_FILTER, 40 | ) 41 | 42 | print("Subscribing to logs...") 43 | await log_provider.subscribe(simple_log_callback) 44 | 45 | run_duration = 60 46 | print(f"Listening for logs for {run_duration} seconds...") 47 | 48 | try: 49 | for i in range(run_duration): 50 | if not log_provider.is_subscribed(): 51 | print("Subscription lost unexpectedly.") 52 | break 53 | await asyncio.sleep(1) 54 | if i % 10 == 0 and i > 0: 55 | 
print(f"Still subscribed after {i} seconds...") 56 | 57 | except asyncio.CancelledError: 58 | print("Run cancelled.") 59 | finally: 60 | print("Unsubscribing...") 61 | await log_provider.unsubscribe() 62 | print("Example finished.") 63 | 64 | 65 | if __name__ == "__main__": 66 | try: 67 | asyncio.run(main()) 68 | except KeyboardInterrupt: 69 | print("Exiting due to KeyboardInterrupt...") 70 | -------------------------------------------------------------------------------- /examples/fetch_all_markets.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | 4 | from anchorpy import Provider, Wallet 5 | from dotenv import load_dotenv 6 | from solana.rpc.async_api import AsyncClient 7 | from solders.keypair import Keypair 8 | 9 | from driftpy.constants.numeric_constants import MARGIN_PRECISION 10 | from driftpy.drift_client import AccountSubscriptionConfig, DriftClient 11 | 12 | load_dotenv() 13 | 14 | 15 | async def get_all_market_names(): 16 | rpc = os.environ.get("MAINNET_RPC_ENDPOINT") 17 | kp = Keypair() # random wallet 18 | wallet = Wallet(kp) 19 | connection = AsyncClient(rpc) 20 | provider = Provider(connection, wallet) 21 | drift_client = DriftClient( 22 | provider.connection, 23 | provider.wallet, 24 | "mainnet", 25 | account_subscription=AccountSubscriptionConfig("cached"), 26 | ) 27 | await drift_client.subscribe() 28 | all_perps_markets = await drift_client.program.account["PerpMarket"].all() 29 | sorted_all_perps_markets = sorted( 30 | all_perps_markets, key=lambda x: x.account.market_index 31 | ) 32 | result_perp = [ 33 | bytes(x.account.name).decode("utf-8").strip() for x in sorted_all_perps_markets 34 | ] 35 | print("Perp Markets:") 36 | for index, market in enumerate(result_perp): 37 | max_leverage = get_perp_market_max_leverage(drift_client, index) 38 | print(f"{market} - {max_leverage}") 39 | 40 | result = result_perp + result_spot[1:] 41 | return result 42 | 43 | 44 | def 
get_perp_market_max_leverage(drift_client, market_index: int) -> float: 45 | market = drift_client.get_perp_market_account(market_index) 46 | standard_max_leverage = MARGIN_PRECISION / market.margin_ratio_initial 47 | 48 | high_leverage = ( 49 | MARGIN_PRECISION / market.high_leverage_margin_ratio_initial 50 | if market.high_leverage_margin_ratio_initial > 0 51 | else 0 52 | ) 53 | max_leverage = max(standard_max_leverage, high_leverage) 54 | return max_leverage 55 | 56 | 57 | if __name__ == "__main__": 58 | loop = asyncio.new_event_loop() 59 | answer = loop.run_until_complete(get_all_market_names()) 60 | print(answer) 61 | -------------------------------------------------------------------------------- /examples/get_all_users_pnl.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import csv 3 | import logging 4 | import os 5 | from datetime import datetime 6 | 7 | from anchorpy.provider import Provider, Wallet 8 | from dotenv import load_dotenv 9 | from solana.rpc.async_api import AsyncClient 10 | from solders.pubkey import Pubkey 11 | from tqdm import tqdm 12 | 13 | from driftpy.account_subscription_config import AccountSubscriptionConfig 14 | from driftpy.drift_client import DriftClient 15 | from driftpy.drift_user import DriftUser 16 | from driftpy.keypair import load_keypair 17 | from driftpy.user_map.user_map import UserMap 18 | from driftpy.user_map.user_map_config import PollingConfig, UserMapConfig 19 | 20 | logging.basicConfig(level=logging.INFO) 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | load_dotenv() 25 | 26 | 27 | async def main(): 28 | rpc = os.environ.get("MAINNET_RPC_ENDPOINT") 29 | private_key = os.environ.get("PRIVATE_KEY") 30 | kp = load_keypair(private_key) 31 | wallet = Wallet(kp) 32 | connection = AsyncClient(rpc) 33 | provider = Provider(connection, wallet) 34 | drift_client = DriftClient( 35 | provider.connection, 36 | provider.wallet, 37 | "mainnet", 38 | 
account_subscription=AccountSubscriptionConfig("cached"), 39 | ) 40 | await drift_client.subscribe() 41 | 42 | usermap_config = UserMapConfig(drift_client, PollingConfig(frequency=2)) 43 | usermap = UserMap(usermap_config) 44 | await usermap.subscribe() 45 | # make a copy of the usermap 46 | usermap_copy = list(usermap.user_map.keys()) 47 | 48 | # Setup CSV output 49 | filename = f"pnl_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv" 50 | with open(filename, "w", newline="") as csvfile: 51 | writer = csv.writer(csvfile) 52 | writer.writerow( 53 | [ 54 | "User", 55 | "Authority", 56 | "Realized PnL", 57 | "Unrealized PnL", 58 | "Total PnL", 59 | "Total Collateral", 60 | ] 61 | ) 62 | 63 | # Calculate PnL 64 | for user_pubkey in tqdm(usermap_copy): 65 | try: 66 | user_pubkey = Pubkey.from_string(user_pubkey) 67 | user = DriftUser( 68 | drift_client, 69 | user_public_key=user_pubkey, 70 | account_subscription=AccountSubscriptionConfig("cached"), 71 | ) 72 | authority = str(user.get_user_account().authority) 73 | realized_pnl = user.get_user_account().settled_perp_pnl 74 | unrealized_pnl = user.get_unrealized_pnl(with_funding=True) 75 | total_pnl = realized_pnl + unrealized_pnl 76 | collateral = user.get_total_collateral() 77 | 78 | writer.writerow( 79 | [ 80 | str(user_pubkey), 81 | authority, 82 | f"{realized_pnl:.2f}", 83 | f"{unrealized_pnl:.2f}", 84 | f"{total_pnl:.2f}", 85 | f"{collateral:.2f}", 86 | ] 87 | ) 88 | except Exception as e: 89 | logger.error(f"Error calculating PnL for {user_pubkey}: {e}") 90 | 91 | logger.info(f"CSV report generated: {filename}") 92 | 93 | 94 | if __name__ == "__main__": 95 | asyncio.run(main()) 96 | -------------------------------------------------------------------------------- /examples/get_borrows.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | 4 | from anchorpy.provider import Wallet 5 | from dotenv import load_dotenv 6 | from solana.rpc.async_api 
import AsyncClient 7 | 8 | from driftpy.accounts.get_accounts import get_spot_market_account 9 | from driftpy.drift_client import DriftClient 10 | from driftpy.math.spot_balance import ( 11 | calculate_borrow_rate, 12 | calculate_deposit_rate, 13 | calculate_interest_rate, 14 | calculate_utilization, 15 | ) 16 | from driftpy.math.spot_market import get_token_amount 17 | from driftpy.types import SpotBalanceType 18 | 19 | 20 | async def main(): 21 | load_dotenv() 22 | url = os.getenv("RPC_URL") 23 | connection = AsyncClient(url) 24 | dc = DriftClient(connection, Wallet.dummy(), "mainnet") 25 | market = await get_spot_market_account(dc.program, 0) # USDC 26 | if market is None: 27 | raise Exception("No market found") 28 | token_deposit_amount = get_token_amount( 29 | market.deposit_balance, 30 | market, 31 | SpotBalanceType.Deposit(), # type: ignore 32 | ) 33 | 34 | token_borrow_amount = get_token_amount( 35 | market.borrow_balance, 36 | market, 37 | SpotBalanceType.Borrow(), # type: ignore 38 | ) 39 | print(f"token_deposit_amount: {(token_deposit_amount/10**market.decimals):,.2f}") 40 | print(f"token_borrow_amount: {(token_borrow_amount/10**market.decimals):,.2f}") 41 | 42 | borrow_rate = calculate_borrow_rate(market) 43 | deposit_rate = calculate_deposit_rate(market) 44 | utilization = calculate_utilization(market) 45 | interest_rate = calculate_interest_rate(market) 46 | 47 | precision = 10000 48 | print(f"borrow_rate: {borrow_rate/precision:.2f}%") 49 | print(f"deposit_rate: {deposit_rate/precision:.2f}%") 50 | print(f"utilization: {utilization/precision:.2f}%") 51 | print(f"interest_rate: {interest_rate/precision:.2f}%") 52 | 53 | 54 | if __name__ == "__main__": 55 | asyncio.run(main()) 56 | -------------------------------------------------------------------------------- /examples/get_deposits.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | 4 | from anchorpy.provider import Wallet 5 | from 
dotenv import load_dotenv 6 | from solana.rpc.async_api import AsyncClient 7 | 8 | from driftpy.account_subscription_config import AccountSubscriptionConfig 9 | from driftpy.constants.numeric_constants import SPOT_BALANCE_PRECISION 10 | from driftpy.drift_client import DriftClient 11 | from driftpy.market_map.market_map import MarketMap 12 | from driftpy.market_map.market_map_config import MarketMapConfig 13 | from driftpy.market_map.market_map_config import ( 14 | WebsocketConfig as MarketMapWebsocketConfig, 15 | ) 16 | from driftpy.pickle.vat import Vat 17 | from driftpy.types import MarketType, is_variant 18 | from driftpy.user_map.user_map import UserMap 19 | from driftpy.user_map.user_map_config import UserMapConfig, UserStatsMapConfig 20 | from driftpy.user_map.user_map_config import ( 21 | WebsocketConfig as UserMapWebsocketConfig, 22 | ) 23 | from driftpy.user_map.userstats_map import UserStatsMap 24 | 25 | load_dotenv() 26 | 27 | 28 | def get_deposits_by_authority(vat: Vat, market_index: int): 29 | deposits = {} 30 | 31 | for user in vat.users.values(): 32 | for position in user.get_user_account().spot_positions: 33 | if ( 34 | position.market_index == market_index 35 | and position.scaled_balance > 0 36 | and not is_variant(position.balance_type, "Borrow") 37 | ): 38 | authority = user.user_public_key 39 | balance = position.scaled_balance / SPOT_BALANCE_PRECISION 40 | 41 | if authority in deposits: 42 | deposits[authority] += balance 43 | else: 44 | deposits[authority] = balance 45 | 46 | return { 47 | "deposits": [ 48 | {"authority": authority, "balance": balance} 49 | for authority, balance in sorted( 50 | deposits.items(), key=lambda x: x[1], reverse=True 51 | ) 52 | ] 53 | } 54 | 55 | 56 | async def main(): 57 | rpc_url = os.getenv("MAINNET_RPC_ENDPOINT") 58 | if not rpc_url: 59 | raise ValueError("MAINNET_RPC_ENDPOINT is not set") 60 | 61 | connection = AsyncClient(rpc_url) 62 | wallet = Wallet.dummy() 63 | dc = DriftClient( 64 | connection, 65 | 
wallet, 66 | "mainnet", 67 | account_subscription=AccountSubscriptionConfig("cached"), 68 | ) 69 | perp_map = MarketMap( 70 | MarketMapConfig( 71 | dc.program, 72 | MarketType.Perp(), # type: ignore 73 | MarketMapWebsocketConfig(), 74 | dc.connection, 75 | ) 76 | ) 77 | spot_map = MarketMap( 78 | MarketMapConfig( 79 | dc.program, 80 | MarketType.Spot(), # type: ignore 81 | MarketMapWebsocketConfig(), 82 | dc.connection, 83 | ) 84 | ) 85 | user_map = UserMap(UserMapConfig(dc, UserMapWebsocketConfig())) 86 | stats_map = UserStatsMap(UserStatsMapConfig(dc)) 87 | await asyncio.gather( 88 | asyncio.create_task(spot_map.subscribe()), 89 | asyncio.create_task(perp_map.subscribe()), 90 | asyncio.create_task(user_map.subscribe()), 91 | asyncio.create_task(stats_map.subscribe()), 92 | ) 93 | print("Subscribed to Drift Client") 94 | await user_map.sync() 95 | print("Synced User Map") 96 | 97 | vat = Vat( 98 | dc, 99 | user_map, 100 | stats_map, 101 | spot_map, 102 | perp_map, 103 | ) 104 | deposits = get_deposits_by_authority(vat, 0) 105 | return deposits 106 | 107 | 108 | if __name__ == "__main__": 109 | deposits = asyncio.run(main()) 110 | print(deposits) 111 | -------------------------------------------------------------------------------- /examples/get_funding.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | 4 | from anchorpy.provider import Wallet 5 | from dotenv import load_dotenv 6 | from solana.rpc.async_api import AsyncClient 7 | 8 | from driftpy.accounts.get_accounts import get_perp_market_account 9 | from driftpy.accounts.oracle import get_oracle_price_data_and_slot 10 | from driftpy.constants import FUNDING_RATE_PRECISION, QUOTE_PRECISION 11 | from driftpy.drift_client import DriftClient 12 | from driftpy.math.funding import ( 13 | calculate_live_mark_twap, 14 | calculate_long_short_funding_and_live_twaps, 15 | ) 16 | 17 | 18 | async def main(): 19 | load_dotenv() 20 | url = 
os.getenv("RPC_URL") 21 | connection = AsyncClient(url) 22 | dc = DriftClient(connection, Wallet.dummy(), "mainnet") 23 | await dc.subscribe() 24 | 25 | market = await get_perp_market_account(dc.program, 0) # SOL-PERP 26 | if market is None: 27 | raise Exception("No market found") 28 | 29 | oracle_price = await get_oracle_price_data_and_slot( 30 | connection, market.amm.oracle, market.amm.oracle_source 31 | ) 32 | oracle_price_data = oracle_price.data 33 | 34 | now = int(asyncio.get_event_loop().time()) 35 | mark_price = market.amm.historical_oracle_data.last_oracle_price 36 | 37 | ( 38 | mark_twap, 39 | oracle_twap, 40 | long_rate, 41 | short_rate, 42 | ) = await calculate_long_short_funding_and_live_twaps( 43 | market, oracle_price_data, mark_price, now 44 | ) 45 | 46 | precision = FUNDING_RATE_PRECISION 47 | print(f"Long Funding Rate: {long_rate/precision}%") 48 | print( 49 | f"Last 24h Avg Funding Rate: {market.amm.last24h_avg_funding_rate/precision}%" 50 | ) 51 | print(f"Last Funding Rate: {market.amm.last_funding_rate/precision}%") 52 | print(f"Last Funding Rate Long: {market.amm.last_funding_rate_long/precision}%") 53 | print(f"Last Funding Rate Short: {market.amm.last_funding_rate_short/precision}%") 54 | 55 | print(f"Oracle Price TWAP: ${oracle_twap/QUOTE_PRECISION:.2f}") 56 | 57 | live_mark_twap = calculate_live_mark_twap( 58 | market, oracle_price_data, mark_price, now 59 | ) 60 | print(f"Live Mark TWAP: ${live_mark_twap/QUOTE_PRECISION:.2f}") 61 | 62 | 63 | if __name__ == "__main__": 64 | asyncio.run(main()) 65 | -------------------------------------------------------------------------------- /examples/get_vault_metrics.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | 4 | import dotenv 5 | from solana.rpc.async_api import AsyncClient 6 | 7 | from driftpy.vaults import VaultClient 8 | 9 | dotenv.load_dotenv() 10 | 11 | connection = AsyncClient(os.getenv("RPC_TRITON")) 12 | 13 | 14 | 
def format_vault_summary(analytics, top_n=5): 15 | """Format vault data for readable output""" 16 | output = [] 17 | output.append(f"Total Vaults: {analytics['total_vaults']}") 18 | output.append(f"Total Value Locked: {analytics['total_deposits']:,.2f}") 19 | output.append(f"Total Unique Depositors: {analytics['total_depositors']}") 20 | 21 | output.append("\n----- Top Vaults by Deposits -----") 22 | for i, vault in enumerate(analytics["top_by_deposits"][:top_n], 1): 23 | output.append(f"{i}. {vault['name']}: {vault['true_net_deposits']:,.2f}") 24 | 25 | output.append("\n----- Top Vaults by Users -----") 26 | for i, vault in enumerate(analytics["top_by_users"][:top_n], 1): 27 | output.append(f"{i}. {vault['name']}: {vault['depositor_count']} users") 28 | 29 | return "\n".join(output) 30 | 31 | 32 | async def main(): 33 | vault_client = await VaultClient(connection).initialize() 34 | print("Got vault client...") 35 | analytics = await vault_client.calculate_analytics() 36 | print("Got analytics...") 37 | depositors = await vault_client.get_all_depositors() 38 | print(f"Got {len(depositors)} depositors...") 39 | print("Top 5 depositors:") 40 | for i, dep in enumerate(depositors[:5], 1): 41 | print(f"{i}. {dep}") 42 | 43 | print(format_vault_summary(analytics)) 44 | drift_boost = await vault_client.get_vault_by_name("SOL-NL-Neutral-Trade") 45 | 46 | if drift_boost: 47 | depositors = await vault_client.get_vault_depositors_with_stats( 48 | drift_boost.account.pubkey 49 | ) 50 | print("\nTop 5 depositors:") 51 | for i, dep in enumerate(depositors[:5], 1): 52 | print( 53 | f"{i}. 
{dep['pubkey']}: {dep['shares']}, {dep['share_percentage']:.2f}%" 54 | ) 55 | 56 | 57 | if __name__ == "__main__": 58 | asyncio.run(main()) 59 | -------------------------------------------------------------------------------- /examples/grpc_auction.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | 4 | from anchorpy.provider import Provider, Wallet 5 | from dotenv import load_dotenv 6 | from solana.rpc.async_api import AsyncClient 7 | from solders.keypair import Keypair 8 | 9 | from driftpy.account_subscription_config import AccountSubscriptionConfig 10 | from driftpy.auction_subscriber.grpc_auction_subscriber import GrpcAuctionSubscriber 11 | 12 | # Auction subscriber imports 13 | from driftpy.auction_subscriber.types import GrpcAuctionSubscriberConfig 14 | 15 | # Drift imports 16 | from driftpy.drift_client import DriftClient 17 | from driftpy.types import GrpcConfig, UserAccount 18 | 19 | load_dotenv() 20 | 21 | 22 | # Example callback 23 | def on_auction_account_update(user_account: UserAccount, pubkey, slot: int): 24 | """ 25 | This function is called whenever an auctioned User account is updated. 26 | """ 27 | print( 28 | f"[AUCTION UPDATE] Slot={slot} PublicKey={pubkey}, UserAccount={user_account}" 29 | ) 30 | 31 | 32 | async def main(): 33 | """ 34 | Main entrypoint: create a gRPC-based Drift client, set up the GrpcAuctionSubscriber, 35 | and attach a callback that prints changes to auction accounts. 36 | """ 37 | 38 | # 1) Load environment variables 39 | rpc_fqdn = os.environ.get("RPC_FQDN") # e.g. "grpcs://my-geyser-endpoint.com:443" 40 | x_token = os.environ.get("X_TOKEN") # your auth token 41 | private_key = os.environ.get("PRIVATE_KEY") # base58-encoded, e.g. "42Ab..." 
42 | rpc_url = os.environ.get("RPC_TRITON") # normal Solana JSON-RPC for sending tx 43 | 44 | if not (rpc_fqdn and x_token and private_key and rpc_url): 45 | raise ValueError("RPC_FQDN, X_TOKEN, PRIVATE_KEY, and RPC_TRITON must be set") 46 | 47 | wallet = Wallet(Keypair.from_base58_string(private_key)) 48 | connection = AsyncClient(rpc_url) 49 | provider = Provider(connection, wallet) 50 | 51 | drift_client = DriftClient( 52 | provider.connection, 53 | provider.wallet, 54 | "mainnet", 55 | account_subscription=AccountSubscriptionConfig( 56 | "grpc", 57 | grpc_config=GrpcConfig( 58 | endpoint=rpc_fqdn, 59 | token=x_token, 60 | ), 61 | ), 62 | ) 63 | 64 | await drift_client.subscribe() 65 | 66 | auction_subscriber_config = GrpcAuctionSubscriberConfig( 67 | drift_client=drift_client, 68 | grpc_config=GrpcConfig(endpoint=rpc_fqdn, token=x_token), 69 | commitment=provider.connection.commitment, # or "confirmed" 70 | ) 71 | auction_subscriber = GrpcAuctionSubscriber(auction_subscriber_config) 72 | 73 | auction_subscriber.event_emitter.on_account_update += on_auction_account_update 74 | await auction_subscriber.subscribe() 75 | print("AuctionSubscriber is now watching for changes to in-auction User accounts.") 76 | 77 | try: 78 | while True: 79 | await asyncio.sleep(10) 80 | except KeyboardInterrupt: 81 | print("Unsubscribing from auction accounts...") 82 | auction_subscriber.unsubscribe() 83 | 84 | 85 | if __name__ == "__main__": 86 | asyncio.run(main()) 87 | print("Done.") 88 | -------------------------------------------------------------------------------- /examples/grpc_client.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | 4 | from anchorpy.provider import Provider, Wallet 5 | from dotenv import load_dotenv 6 | from solana.rpc.async_api import AsyncClient 7 | from solana.rpc.commitment import Commitment 8 | from solders.keypair import Keypair 9 | 10 | from driftpy.drift_client import 
AccountSubscriptionConfig, DriftClient 11 | from driftpy.types import GrpcConfig 12 | 13 | load_dotenv() 14 | 15 | RED = "\033[91m" 16 | GREEN = "\033[92m" 17 | RESET = "\033[0m" 18 | 19 | CLEAR_SCREEN = "\033c" 20 | 21 | 22 | async def watch_drift_markets(): 23 | rpc_fqdn = os.environ.get("RPC_FQDN") 24 | x_token = os.environ.get("X_TOKEN") 25 | private_key = os.environ.get("PRIVATE_KEY") 26 | rpc_url = os.environ.get("RPC_TRITON") 27 | 28 | if not (rpc_fqdn and x_token and private_key and rpc_url): 29 | raise ValueError("RPC_FQDN, X_TOKEN, PRIVATE_KEY, and RPC_TRITON must be set") 30 | 31 | wallet = Wallet(Keypair.from_base58_string(private_key)) 32 | connection = AsyncClient(rpc_url) 33 | provider = Provider(connection, wallet) 34 | 35 | drift_client = DriftClient( 36 | provider.connection, 37 | provider.wallet, 38 | "mainnet", 39 | account_subscription=AccountSubscriptionConfig( 40 | "grpc", 41 | grpc_config=GrpcConfig( 42 | endpoint=rpc_fqdn, 43 | token=x_token, 44 | commitment=Commitment("confirmed"), 45 | ), 46 | ), 47 | ) 48 | 49 | await drift_client.subscribe() 50 | print("Subscribed via gRPC. 
Listening for market updates...") 51 | 52 | previous_prices = {} 53 | 54 | while True: 55 | print(CLEAR_SCREEN, end="") 56 | 57 | perp_markets = drift_client.get_perp_market_accounts() 58 | 59 | if not perp_markets: 60 | print(f"{RED}No perp markets found (yet){RESET}") 61 | else: 62 | print("Drift Perp Markets (gRPC subscription)\n") 63 | perp_markets.sort(key=lambda x: x.market_index) 64 | for market in perp_markets[:20]: 65 | market_index = market.market_index 66 | last_price = market.amm.historical_oracle_data.last_oracle_price / 1e6 67 | 68 | if market_index in previous_prices: 69 | old_price = previous_prices[market_index] 70 | if last_price > old_price: 71 | color = GREEN 72 | elif last_price < old_price: 73 | color = RED 74 | else: 75 | color = RESET 76 | else: 77 | color = RESET 78 | 79 | print( 80 | f"Market Index: {market_index} | " 81 | f"Price: {color}${last_price:.4f}{RESET}" 82 | ) 83 | 84 | previous_prices[market_index] = last_price 85 | 86 | await asyncio.sleep(1) 87 | 88 | 89 | if __name__ == "__main__": 90 | asyncio.run(watch_drift_markets()) 91 | -------------------------------------------------------------------------------- /examples/high_leverage_users.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | 4 | from anchorpy import Wallet 5 | from dotenv import load_dotenv 6 | from solana.rpc.async_api import AsyncClient 7 | 8 | from driftpy.drift_client import DriftClient 9 | from driftpy.user_map.user_map import UserMap 10 | from driftpy.user_map.user_map_config import UserMapConfig 11 | from driftpy.user_map.user_map_config import ( 12 | WebsocketConfig as UserMapWebsocketConfig, 13 | ) 14 | 15 | 16 | async def main(): 17 | load_dotenv() 18 | url = os.getenv("RPC_TRITON") 19 | connection = AsyncClient(url) 20 | dc = DriftClient( 21 | connection, 22 | Wallet.dummy(), 23 | "mainnet", 24 | ) 25 | await dc.subscribe() 26 | user = UserMap(UserMapConfig(dc, 
UserMapWebsocketConfig(), include_idle=False)) 27 | await user.subscribe() 28 | 29 | high_leverage_users = [] 30 | keys = [] 31 | for key, user in user.user_map.items(): 32 | if user.is_high_leverage_mode(): 33 | high_leverage_users.append(user) 34 | keys.append(key) 35 | 36 | return high_leverage_users, keys 37 | 38 | 39 | if __name__ == "__main__": 40 | high_leverage_users, keys = asyncio.run(main()) 41 | keys.sort() 42 | print(f"Number of high leverage users: {len(keys)}") 43 | # with open("high_leverage_users.txt", "w") as f: 44 | # for key in keys: 45 | # f.write(f"{key}\n") 46 | -------------------------------------------------------------------------------- /examples/pmm_users.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | 4 | from anchorpy import Wallet 5 | from dotenv import load_dotenv 6 | from solana.rpc.async_api import AsyncClient 7 | 8 | from driftpy.addresses import get_protected_maker_mode_config_public_key 9 | from driftpy.drift_client import DriftClient 10 | from driftpy.math.user_status import is_user_protected_maker 11 | from driftpy.user_map.user_map import UserMap 12 | from driftpy.user_map.user_map_config import UserMapConfig 13 | from driftpy.user_map.user_map_config import ( 14 | WebsocketConfig as UserMapWebsocketConfig, 15 | ) 16 | 17 | 18 | async def main(): 19 | load_dotenv() 20 | url = os.getenv("RPC_TRITON") 21 | connection = AsyncClient(url) 22 | dc = DriftClient( 23 | connection, 24 | Wallet.dummy(), 25 | "mainnet", 26 | ) 27 | await dc.subscribe() 28 | user = UserMap(UserMapConfig(dc, UserMapWebsocketConfig(), include_idle=True)) 29 | await user.subscribe() 30 | 31 | pmm_users = [] 32 | keys = [] 33 | print(get_protected_maker_mode_config_public_key(dc.program_id)) 34 | for key, user in user.user_map.items(): 35 | if is_user_protected_maker(user.get_user_account()): 36 | pmm_users.append(user) 37 | keys.append(key) 38 | 39 | return pmm_users, keys 40 | 41 | 
42 | if __name__ == "__main__": 43 | pmm_users, keys = asyncio.run(main()) 44 | keys.sort() 45 | print(f"Number of protected maker users: {len(keys)}") 46 | -------------------------------------------------------------------------------- /examples/protected_order.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | 4 | from anchorpy.provider import Wallet 5 | from dotenv import load_dotenv 6 | from solana.rpc.async_api import AsyncClient 7 | from solders.keypair import Keypair 8 | 9 | from driftpy.addresses import get_protected_maker_mode_config_public_key 10 | from driftpy.drift_client import DriftClient 11 | from driftpy.math.user_status import is_user_protected_maker 12 | 13 | load_dotenv() 14 | 15 | 16 | async def is_protected_maker(): 17 | rpc = os.environ.get("RPC_TRITON") 18 | kp = Keypair.from_base58_string(os.environ.get("PRIVATE_KEY", "")) 19 | wallet = Wallet(kp) 20 | connection = AsyncClient(rpc) 21 | drift_client = DriftClient( 22 | connection=connection, 23 | wallet=wallet, 24 | env="mainnet", 25 | ) 26 | 27 | await drift_client.subscribe() 28 | config_account = await drift_client.program.account[ 29 | "ProtectedMakerModeConfig" 30 | ].fetch(get_protected_maker_mode_config_public_key(drift_client.program_id)) 31 | print(config_account) 32 | user = drift_client.get_user() 33 | print(user.get_open_orders()) 34 | print(get_protected_maker_mode_config_public_key(drift_client.program_id)) 35 | print( 36 | "Is user protected maker: ", 37 | is_user_protected_maker(user.get_user_account()), 38 | ) 39 | print(await drift_client.update_user_protected_maker_orders(0, True)) 40 | print( 41 | "Is user protected maker: ", 42 | is_user_protected_maker(user.get_user_account()), 43 | ) 44 | await drift_client.unsubscribe() 45 | 46 | 47 | if __name__ == "__main__": 48 | asyncio.run(is_protected_maker()) 49 | -------------------------------------------------------------------------------- 
/examples/readme.md:
--------------------------------------------------------------------------------
1 | ## Staking in the Insurance Fund
2 |
3 | How to run:
4 |
5 | ```bash
6 | # --keypath: path to your keypair file
7 | # --amount:  USD amount to stake
8 | # --market:  spot market index (0 = USD spot market)
9 | python if_stake.py \
10 |     --keypath ../keypairs/x19.json \
11 |     --amount 100 \
12 |     --market 0 \
13 |     --operation add \
14 |     --env mainnet
15 | ```
--------------------------------------------------------------------------------
/examples/settle_pnl.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | import os
4 |
5 | from anchorpy import Wallet
6 | from dotenv import load_dotenv
7 | from driftpy.accounts import get_user_account_and_slot
8 | from driftpy.drift_client import DriftClient
9 | from solana.rpc.async_api import AsyncClient
10 | from solders.keypair import Keypair
11 | from solders.pubkey import Pubkey
12 |
13 |
14 | load_dotenv()
15 |
16 |
17 | async def setup_drift_client() -> DriftClient:
18 |     """Build a mainnet DriftClient from RPC_URL / PRIVATE_KEY and subscribe it."""
19 |     rpc_url = os.getenv("RPC_URL")
20 |     private_key = os.getenv("PRIVATE_KEY")
21 |     if not rpc_url or not private_key:
22 |         raise ValueError("RPC_URL and PRIVATE_KEY must be set")
23 |     connection = AsyncClient(rpc_url)
24 |     kp = Keypair.from_base58_string(private_key)
25 |     wallet = Wallet(kp)
26 |     drift_client = DriftClient(connection=connection, wallet=wallet, env="mainnet")
27 |     await drift_client.subscribe()
28 |     await drift_client.account_subscriber.fetch()
29 |     print("drift_client subscribe done")
30 |     return drift_client
31 |
32 |
33 | async def settle_pnl_once(
34 |     drift_client: DriftClient, user_public_key: Pubkey, market_index: int
35 | ):
36 |     """Settle PnL in `market_index` for the user account at `user_public_key`."""
37 |     # Load the settlee's own account. `drift_client.get_user()` (no argument) is
38 |     # not tied to `user_public_key` (presumably the client's default sub-account),
39 |     # so using it would pair one account's data with another account's key.
40 |     user_and_slot = await get_user_account_and_slot(
41 |         drift_client.program, user_public_key
42 |     )
43 |     if user_and_slot is None:
44 |         raise ValueError(f"User account {user_public_key} not found")
45 |     tx_hash = await drift_client.settle_pnl(
46 |         user_public_key, user_and_slot.data, market_index
47 |     )
48 |     print(f"Settled PNL for market {market_index}, tx: {tx_hash}")
49 |
50 |
51 | async def main():
52 |     parser = argparse.ArgumentParser()
53 |     parser.add_argument("--subaccount_id", type=int, required=True)
54 |     parser.add_argument("--market_index", type=int, required=True)
55 |     args = parser.parse_args()
56 |
57 |     drift_client = await setup_drift_client()
58 |     try:
59 |         user_public_key = drift_client.get_user_account_public_key(args.subaccount_id)
60 |         await asyncio.wait_for(
61 |             settle_pnl_once(drift_client, user_public_key, args.market_index),
62 |             timeout=540,
63 |         )
64 |     finally:
65 |         # Release the subscription and the RPC session even on failure/timeout.
66 |         await drift_client.unsubscribe()
67 |         await drift_client.connection.close()
68 |
69 |
70 | if __name__ == "__main__":
71 |     asyncio.run(main())
72 |
--------------------------------------------------------------------------------
/examples/spot_market_trade.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | import os
4 |
5 | from anchorpy.provider import Provider, Wallet
6 | from dotenv import load_dotenv
7 | from solana.rpc.async_api import AsyncClient
8 |
9 | from driftpy.constants.spot_markets import mainnet_spot_market_configs
10 | from driftpy.drift_client import DriftClient
11 | from driftpy.keypair import load_keypair
12 | from driftpy.types import TxParams
13 |
14 | logging.basicConfig(level=logging.INFO)
15 | logger = logging.getLogger(__name__)
16 |
17 | load_dotenv()
18 |
19 |
20 | def get_market_by_symbol(symbol: str):
21 |     for market in mainnet_spot_market_configs:
22 |         if market.symbol == symbol:
23 |             return market
24 |     raise Exception(f"Market {symbol} not found")
25 |
26 |
27 | async def make_spot_trade():
28 |     rpc = os.environ.get("RPC_TRITON")
29 |     secret = os.environ.get("PRIVATE_KEY")
30 |     kp = load_keypair(secret)
31 |     wallet = Wallet(kp)
32 |     logger.info(f"Using wallet: {wallet.public_key}")
33 |
34 |     connection = AsyncClient(rpc)
35 |     provider = Provider(connection, wallet)
36 |     drift_client = DriftClient(
37 |         provider.connection,
38 |         provider.wallet,
39 |         env="mainnet",
40 |         tx_params=TxParams(compute_units_price=85_000, compute_units=1_400_000),
41 |     )
42 |     await drift_client.subscribe()
43 |
44 |     logger.info("Drift client subscribed")
45 |
46 |     market_symbol_1 = "JLP"
47 |     market_symbol_2 = "USDC"
48 |
49 |     in_decimals_result = drift_client.get_spot_market_account(
50 |         get_market_by_symbol(market_symbol_1).market_index
51 |     )
52 |     if not in_decimals_result:
53 |         logger.error(f"{market_symbol_1} market not found")
54 |         raise Exception("Market not found")
55 |
56 |     in_decimals = in_decimals_result.decimals
57 |     logger.info(f"{market_symbol_1} decimals: {in_decimals}")
58 |
59 |     swap_amount = int(1 * 10**in_decimals)
60 |     logger.info(f"Swapping {swap_amount} {market_symbol_1} to {market_symbol_2}")
61 |
62 |     swap_ixs, swap_lookups = await drift_client.get_jupiter_swap_ix_v6(
63 |         out_market_idx=get_market_by_symbol(market_symbol_2).market_index,
64 |         in_market_idx=get_market_by_symbol(market_symbol_1).market_index,
65 |         amount=swap_amount,
66 |         swap_mode="ExactIn",
67 |         only_direct_routes=True,
68 |         max_accounts=20,
69 |     )
70 |     await drift_client.send_ixs(
71 |         ixs=swap_ixs,
72 |         lookup_tables=swap_lookups,
73 |     )
74 |     logger.info("Swap complete")
75 |
76 |
77 | if __name__ == "__main__":
78 |     asyncio.run(make_spot_trade())
79 |
--------------------------------------------------------------------------------
/examples/view.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | import time
4 | from pprint import pprint
5 |
6 | import dotenv
7 | from anchorpy import Wallet
8 | from solana.rpc.async_api import AsyncClient
9 | from solders.pubkey import Pubkey
10 |
11 | from driftpy.account_subscription_config import AccountSubscriptionConfig
12 | from driftpy.constants.numeric_constants import (
13 |     MARGIN_PRECISION,
14 |     QUOTE_PRECISION,
15 | )
16 | from driftpy.decode.utils import decode_name
17 | from driftpy.drift_client import DriftClient
18 | from driftpy.drift_user import DriftUser
19 | from driftpy.keypair import load_keypair
20 | from driftpy.math.conversion import convert_to_number
21 | from driftpy.math.perp_position import is_available
22 |
23 | dotenv.load_dotenv()
24 |
25 |
26 | async def main():
27 |     s = time.time()
28 |     kp = load_keypair(os.getenv("PRIVATE_KEY"))
29 |     wallet = Wallet(kp)
30 |     connection = AsyncClient(os.getenv("RPC_TRITON"))
31 |     dc = DriftClient(
32 |         connection,
33 |         wallet,
34 |         account_subscription=AccountSubscriptionConfig("websocket"),
35 |     )
36 |     await dc.subscribe()
37 |     drift_user = dc.get_user()
38 |     user = drift_user.get_user_account()
39 |     print("\n=== User Info ===")
40 |     print(f"Subaccount name: {decode_name(user.name)}")
41 |
42 |     spot_collateral = drift_user.get_spot_market_asset_value(
43 |         None,
44 |         include_open_orders=True,
45 |     )
46 |     print("\n=== Collateral & PnL ===")
47 |     print(f"Spot collateral: ${spot_collateral / QUOTE_PRECISION:,.2f}")
48 |
49 |     pnl = drift_user.get_unrealized_pnl(False)
50 |     print(f"Unrealized PnL: ${pnl / QUOTE_PRECISION:,.2f}")
51 |
52 |     total_collateral = drift_user.get_total_collateral()
53 |     print(f"Total collateral: ${total_collateral / QUOTE_PRECISION:,.2f}")
54 |
55 |     perp_liability = drift_user.get_total_perp_position_liability()
56 |     spot_liability = drift_user.get_spot_market_liability_value()
57 |     print("\n=== Liabilities ===")
58 |     print(f"Perp liability: ${perp_liability / QUOTE_PRECISION:,.2f}")
59 |     print(f"Spot liability: ${spot_liability / QUOTE_PRECISION:,.2f}")
60 |
61 |     perp_market = dc.get_perp_market_account(0)
62 |     oracle = convert_to_number(
63 |         dc.get_oracle_price_data_for_perp_market(0).price, QUOTE_PRECISION
64 |     )
65 |     print("\n=== Market Info ===")
66 |     print(f"Oracle price: ${oracle:,.2f}")
67 |
68 |     init_leverage = MARGIN_PRECISION / perp_market.margin_ratio_initial
69 |     maint_leverage = MARGIN_PRECISION / perp_market.margin_ratio_maintenance
70 |     print(f"Initial leverage: {init_leverage:.2f}x")
71 |     print(f"Maintenance leverage: {maint_leverage:.2f}x")
72 |
73 |     liq_price = drift_user.get_perp_liq_price(0)
74 |     print(f"Liquidation price: ${liq_price:,.2f}")
75 |
76 |     total_liability = drift_user.get_margin_requirement(None)
77 |     total_asset_value = drift_user.get_total_collateral()
78 |     print("\n=== Risk Metrics ===")
79 |     print(f"Total liability: ${total_liability / QUOTE_PRECISION:,.2f}")
80 |     print(f"Total asset value: ${total_asset_value / QUOTE_PRECISION:,.2f}")
81 |     print(f"Current leverage: {(drift_user.get_leverage()) / 10_000:.2f}x")
82 |
83 |     user = drift_user.get_user_account()
84 |     print("\n=== Perp Positions ===")
85 |     for position in user.perp_positions:
86 |         if not is_available(position):
87 |             pprint(position)
88 |
89 |     print(f"\nTime taken: {time.time() - s:.2f}s")
90 |
91 |     orders = drift_user.get_open_orders()
92 |     print("\n=== Orders ===")
93 |     pprint(orders, indent=2, width=80)
94 |
95 |     print("\n=== Health Components ===")
96 |     pprint(drift_user.get_health_components(), indent=2, width=80)
97 |
98 |     # Another user
99 |     drift_user2 = DriftUser(
100 |         drift_client=dc,
101 |         user_public_key=Pubkey.from_string(os.getenv("RANDOM_USER_ACCOUNT_PUBKEY")),
102 |     )
103 |     await drift_user2.subscribe()
104 |
105 |     print("\n=== Perp Positions (User 2) ===")
106 |     for position in drift_user2.get_user_account().perp_positions:
107 |         if position.base_asset_amount != 0:
108 |             pprint(position)
109 |
110 |     pnl = drift_user2.get_net_usd_value()
111 |     print(f"Net USD Value: ${pnl / QUOTE_PRECISION:,.2f}")
112 |
113 |
114 | if __name__ == "__main__":
115 |     asyncio.run(main())
116 |
--------------------------------------------------------------------------------
/examples/ws_simple.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | import os
4 |
5 | import dotenv
6 | from anchorpy.provider import Wallet
7 | from solana.rpc.async_api import AsyncClient
8 | from solders.keypair import Keypair
9 |
10 | from driftpy.accounts.ws.drift_client import WebsocketDriftClientAccountSubscriber
11 | from driftpy.drift_client import DriftClient
12 | from driftpy.keypair import load_keypair
13 | from driftpy.types import (
14 |     TxParams,
15 | )
16 |
17 | logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | async def get_all_spot_indexes(rpc_url: str):
22 |     throwaway_drift_client = DriftClient(
23 |         connection=AsyncClient(rpc_url),
24 |         wallet=Wallet(Keypair()),
25 |         env="mainnet",
26 |         spot_market_indexes=[],
27 |     )
28 |     spot_markets = await throwaway_drift_client.program.account["SpotMarket"].all()
29 |     await throwaway_drift_client.unsubscribe()
30 |     await throwaway_drift_client.connection.close()
31 |     return [market.account.market_index for market in spot_markets]
32 |
33 |
34 | async def get_all_perp_indexes(rpc_url: str):
35 |     throwaway_drift_client = DriftClient(
36 |         connection=AsyncClient(rpc_url),
37 |         wallet=Wallet(Keypair()),
38 |         env="mainnet",
39 |         perp_market_indexes=[],
40 |     )
41 |     perp_markets = await throwaway_drift_client.program.account["PerpMarket"].all()
42 |     await throwaway_drift_client.unsubscribe()
43 |     await throwaway_drift_client.connection.close()
44 |     return [market.account.market_index for market in perp_markets]
45 |
46 |
47 | async def main():
48 |     logger.info("Starting...")
49 |     dotenv.load_dotenv()
50 |     rpc_url = os.getenv("RPC_TRITON")
51 |     private_key = os.getenv("PRIVATE_KEY")
52 |     if not rpc_url or not private_key:
53 |         raise Exception("Missing env vars")
54 |     kp = load_keypair(private_key)
55 |
56 |     drift_client = DriftClient(
57 |         connection=AsyncClient(rpc_url),
58 |         wallet=Wallet(kp),
59 |         env="mainnet",
60 |         tx_params=TxParams(700_000, 10_000),
61 |     )
62 |     await drift_client.subscribe()
63 |     market_index = 65
64 |
65 |     # Everything after subscribe() sits inside try so the client is always released.
66 |     try:
67 |         perp_markets = drift_client.get_perp_market_account(market_index)
68 |         if perp_markets is None:
69 |             raise Exception("No perp markets found")
70 |         print(perp_markets)
71 |
72 |         if not isinstance(
73 |             drift_client.account_subscriber, WebsocketDriftClientAccountSubscriber
74 |         ):
75 |             raise Exception("Account subscriber is not a WebsocketAccountSubscriber")
76 |         while True:
77 |             # Re-read the market every tick instead of reusing `perp_markets`:
78 |             # if updates swap in a newly decoded account object (NOTE(review):
79 |             # confirm in the websocket subscriber), the old reference goes stale.
80 |             perp_market = drift_client.get_perp_market_account(market_index)
81 |             if perp_market is not None:
82 |                 print(perp_market.amm.historical_oracle_data.last_oracle_price / 1e6)
83 |             await asyncio.sleep(5)
84 |     except KeyboardInterrupt:
85 |         logger.info("Interrupted by user. Exiting loop...")
86 |     finally:
87 |         await drift_client.unsubscribe()
88 |         await drift_client.connection.close()
89 |         logger.info("Unsubscribed from Drift client.")
90 |
91 |
92 | if __name__ == "__main__":
93 |     asyncio.run(main())
94 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: DriftPy
2 | theme:
3 |   name: material
4 |   icon:
5 |     logo: material/chart-line
6 |   favicon: img/drift.png
7 |   palette:
8 |     - scheme: default
9 |       primary: deep-purple
10 |       toggle:
11 |         icon: material/toggle-switch-off-outline
12 |         name: Switch to dark mode
13 |     - scheme: slate
14 |       toggle:
15 |         icon: material/toggle-switch
16 |         name: Switch to light mode
17 | markdown_extensions:
18 |   - pymdownx.highlight
19 |   - pymdownx.superfences
20 |   - admonition
21 |   - pymdownx.snippets
22 |   - meta
23 |   - pymdownx.tabbed:
24 |       alternate_style: true
25 | repo_url: https://github.com/drift-labs/driftpy
26 | repo_name: drift-labs/driftpy
27 | site_url: https://drift-labs.github.io/driftpy/
28 | plugins:
29 |   - mkdocstrings:
30 |       handlers:
31 |         python:
32 |           selection:
33 |             filters:
34 |               - "!^_" # exclude all members starting with _
35 |               - "^__init__$" # but always include __init__ modules and methods
36 |               - "!^T$"
37 |               - "!^get_clearing_house_state_account_public_key_and_nonce$"
38 |           rendering:
39 |             show_root_heading: true
40 |             show_root_full_path: false
41 |             show_if_no_docstring: true
42 |   - search
43 | nav:
44 |   - index.md
45 |   - clearing_house.md
46 |   - clearing_house_user.md
47 |   - accounts.md
48 |   - addresses.md
49 | extra_css:
50 |   - css/mkdocstrings.css
51 |   - css/custom.css
52 |
--------------------------------------------------------------------------------
/poetry.toml:
--------------------------------------------------------------------------------
1 | [virtualenvs]
2 | in-project = true
3 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "driftpy"
3 | version = "0.8.56"
4 | description = "A Python client for the Drift DEX"
5 | authors = [
6 |     "x19 ",
7 |     "bigz ",
8 |     "frank ",
9 |     "sina ",
10 | ]
11 | license = "MIT"
12 | readme = "README.md"
13 | homepage = "https://github.com/drift-labs/driftpy"
14 | documentation = "https://drift-labs.github.io/driftpy/"
15 |
16 | [tool.poetry.dependencies]
17 | python = "^3.10"
18 | anchorpy = "0.21.0"
19 | solana = "^0.36"
20 | requests = "^2.28.1"
21 | pythclient = "0.2.1"
22 | aiodns = "3.0.0"
23 | aiohttp = "^3.9.1"
24 | aiosignal = "1.3.1"
25 | anchorpy-core = "0.2.0"
26 | anyio = "4.4.0"
27 | apischema = "0.17.5"
28 | async-timeout = "^4.0.2"
29 | attrs = "22.2.0"
30 | backoff = "2.2.1"
31 | base58 = "2.1.1"
32 | based58 = "0.1.1"
33 | borsh-construct = "0.1.0"
34 | cachetools = "5.3"
35 | certifi = "2022.12.7"
36 | cffi = "1.15.1"
37 | charset-normalizer = "2.1.1"
38 | construct = "2.10.68"
39 | construct-typing = "0.5.3"
40 | dnspython = "2.2.1"
41 | exceptiongroup = "1.0.4"
42 | h11 = "0.14.0"
43 | httpcore = "1.0.7"
44 | httpx = "0.28.1"
45 | idna = "3.4"
46 | iniconfig = "1.1.1"
47 | jsonalias = "0.1.1"
48 | jsonrpcclient = "4.0.3"
49 | jsonrpcserver = "5.0.9"
50 | jsonschema = "4.18.0"
51 | loguru = "^0.7.0"
52 | mccabe = "0.7.0"
53 | more-itertools = "8.14.0"
54 | oslash = "0.6.3"
55 | packaging = "23.1"
56 | psutil = "5.9.4"
57 | py = "1.11.0"
58 | pycares = "4.3.0"
59 | pycodestyle = "2.10.0"
60 | pycparser = "2.21"
61 | pyflakes = "3.0.1"
62 | pyheck = "0.1.5"
63 | pyrsistent = "0.19.2"
64 | rfc3986 = "1.5.0"
65 | sniffio = "1.3.0"
66 | solders = ">=0.23.0,<0.27.0"
67 | sumtypes = "0.1a6"
68 | toml = "0.10.2"
69 | tomli = "2.0.1"
70 | toolz = "0.11.2"
71 | types-cachetools = "4.2.10"
72 | typing-extensions = "^4.4.0"
73 | urllib3 = "1.26.13"
74 | websockets = "13.0"
75 | yarl = "1.9.4"
76 | zstandard = "0.18.0"
77 | deprecated = "^1.2.14"
78 | events = "^0.5"
79 | numpy = "^1.26.2"
80 | grpcio = "1.68.1"
81 | protobuf = "5.29.2"
82 | pynacl = "^1.5.0"
83 | tqdm = "^4.67.1"
84 |
85 | [tool.poetry.group.dev.dependencies]
86 | pytest = "^7.4.4"
87 | flake8 = "6.0.0"
88 | black = "24.4.2"
89 | pytest-asyncio = "0.21.0"
90 | mkdocs = "^1.3.0"
91 | mkdocstrings = "^0.17.0"
92 | mkdocs-material = "^8.1.8"
93 | bump2version = "^1.0.1"
94 | autopep8 = "^2.0.4"
95 | mypy = "^1.7.0"
96 | python-dotenv = "^1.0.0"
97 | ruff = "^0.8.4"
98 | # drift-jit-proxy = ">=0.1.6"
99 | pytest-xprocess = "0.18.1"
100 | types-requests = "^2.28.9"
101 | jinja2 = "^3.1.2"
102 |
103 | [build-system]
104 | requires = ["poetry-core>=1.0.0"]
105 | build-backend = "poetry.core.masonry.api"
106 |
107 | [tool.pytest.ini_options]
108 | asyncio_mode = "strict"
109 |
110 | [tool.ruff]
111 | exclude = [
112 |     ".git",
113 |     "__pycache__",
114 |     "docs/source/conf.py",
115 |     "old",
116 |     "build",
117 |     "dist",
118 |     "**/geyser_codegen/**",
119 | ]
120 |
121 | [tool.ruff.lint.pycodestyle]
122 | max-line-length = 88
123 |
124 | [tool.pyright]
125 | reportMissingModuleSource = false
126 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiodns==3.0.0
2 | aiohttp==3.8.3
3 | aiosignal==1.3.1
4 | anchorpy==0.20.1
5 | anchorpy-core==0.2.0
6 | anyio==3.6.2
7 | apischema==0.17.5
8 | async-timeout==4.0.2
9 | attrs==22.1.0
10 | backoff==2.2.1
11 | base58==2.1.1
12 | based58==0.1.1
13 | borsh-construct==0.1.0
14 | cachetools==4.2.4
15 | certifi==2022.12.7
16 | cffi==1.15.1
17 | charset-normalizer==2.1.1
18 | construct==2.10.68
19 | construct-typing==0.5.3
20 | dnspython==2.2.1
21 | exceptiongroup==1.0.4
22 | flake8==6.0.0
23 | frozenlist==1.3.3
24 | h11==0.14.0
25 | httpcore==0.16.3
26 | httpx==0.23.1
27 | idna==3.4
28 | iniconfig==1.1.1
29 | jsonalias==0.1.1
30 | jsonrpcclient==4.0.2
31 | jsonrpcserver==5.0.9
32 | jsonschema==4.17.3
33 | loguru==0.6.0
34 | mccabe==0.7.0
35 | more-itertools==8.14.0
36 | multidict==6.0.3
37 | OSlash==0.6.3
38 | packaging==22.0
39 | pluggy==1.0.0
40 | psutil==5.9.4
41 | py==1.11.0
42 | pycares==4.3.0
43 | pycodestyle==2.10.0
44 | pycparser==2.21
45 | pyflakes==3.0.1
46 | pyheck==0.1.5
47 | pyrsistent==0.19.2
48 | pytest==6.2.5
49 | pytest-asyncio==0.17.2
50 | pytest-xprocess==0.18.1
51 | pythclient==0.2.1
52 | requests==2.28.1
53 | rfc3986==1.5.0
54 | sniffio==1.3.0
55 | solana==0.36.0
56 | solders==0.21.0
57 | sumtypes==0.1a6
58 | toml==0.10.2
59 | tomli==2.0.1
60 | toolz==0.11.2
61 | types-cachetools==4.2.10
62 | typing_extensions==4.4.0
63 | urllib3==1.26.13
64 | websockets==10.4
65 | yarl==1.8.2
66 | zstandard==0.17.0
67 |
--------------------------------------------------------------------------------
/scripts/bump.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 |
4 |
5 | def bump_version(version):
6 |     """Return `version` ("MAJOR.MINOR.PATCH") with the patch number incremented."""
7 |     major, minor, patch = map(int, version.split("."))
8 |     patch += 1
9 |     return f"{major}.{minor}.{patch}"
10 |
11 |
12 | def update_file(file_path, pattern, replacement):
13 |     """Replace the first match of `pattern` in `file_path` with `replacement`.
14 |
15 |     Returns True if a match was replaced, or False if the pattern was not
16 |     found (in which case the file is left untouched).
17 |     """
18 |     with open(file_path, "r", encoding="utf-8") as file:
19 |         content = file.read()
20 |
21 |     # First match only: main() reads the current version from the first match,
22 |     # so later "...version = X.Y.Z" entries must not be rewritten as well.
23 |     updated_content, count = re.subn(pattern, replacement, content, count=1)
24 |     if count == 0:
25 |         return False
26 |
27 |     with open(file_path, "w", encoding="utf-8") as file:
28 |         file.write(updated_content)
29 |     return True
30 |
31 |
32 | def main():
33 |     script_dir = os.path.dirname(os.path.abspath(__file__))
34 |     project_root = os.path.dirname(script_dir)
35 |
36 |     pyproject_path = os.path.join(project_root, "pyproject.toml")
37 |     pyproject_pattern = r'(version\s*=\s*["\'])(\d+\.\d+\.\d+)(["\'])'
38 |
39 |     init_path = os.path.join(project_root, "src", "driftpy", "__init__.py")
40 |     init_pattern = r'(__version__\s*=\s*["\'])(\d+\.\d+\.\d+)(["\'])'
41 |
42 |     bumpversion_path = os.path.join(project_root, ".bumpversion.cfg")
43 |     bumpversion_pattern = r"(current_version\s*=\s*)(\d+\.\d+\.\d+)"
44 |
45 |     with open(pyproject_path, "r", encoding="utf-8") as file:
46 |         content = file.read()
47 |     match = re.search(pyproject_pattern, content)
48 |     if match:
49 |         current_version = match.group(2)
50 |     else:
51 |         print("Couldn't find version in pyproject.toml")
52 |         return
53 |
54 |     new_version = bump_version(current_version)
55 |
56 |     updates = [
57 |         (pyproject_path, pyproject_pattern, rf"\g<1>{new_version}\g<3>"),
58 |         (init_path, init_pattern, rf"\g<1>{new_version}\g<3>"),
59 |         (bumpversion_path, bumpversion_pattern, rf"\g<1>{new_version}"),
60 |     ]
61 |     for path, pattern, replacement in updates:
62 |         # Surface a silent miss: otherwise the three version strings drift apart.
63 |         if not update_file(path, pattern, replacement):
64 |             print(f"Warning: no version string found in {path}; not updated")
65 |
66 |     print(f"Version bumped from {current_version} to {new_version}")
67 |
68 |
69 | if __name__ == "__main__":
70 |     main()
71 |
--------------------------------------------------------------------------------
/scripts/ci.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 | # `cmd || exit_code=$?` records a failure without tripping `set -e`, so the report below runs.
4 | echo "Running tests:"
5 | exit_code=0
6 | pytest -v -s -x tests/ci/*.py || exit_code=$?
7 | if [ "$exit_code" -eq 0 ]; then pytest -v -s tests/math/*.py || exit_code=$?; fi
8 |
9 |
10 | if [ $exit_code -ne 0 ]; then
11 |     echo "Tests failed with exit code $exit_code"
12 |     exit $exit_code
13 | fi
14 |
15 | echo "All tests passed successfully"
--------------------------------------------------------------------------------
/scripts/decode.sh:
--------------------------------------------------------------------------------
1 | pytest -v -s tests/decode/*.py
--------------------------------------------------------------------------------
/scripts/dlob.sh:
--------------------------------------------------------------------------------
1 | pytest -v -s -x tests/dlob/dlob.py
--------------------------------------------------------------------------------
/scripts/generate_constants.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | import os
4 |
5 | import dotenv
6 | from anchorpy import Wallet
7 | from solana.rpc.async_api import AsyncClient
8 |
9 | from driftpy.account_subscription_config import AccountSubscriptionConfig
10 | from driftpy.constants.perp_markets import PerpMarketConfig
11 | from driftpy.constants.spot_markets import SpotMarketConfig
12 | from driftpy.drift_client import DriftClient
13 |
14 | dotenv.load_dotenv()
15 |
16 |
17 | def decode_name(name) -> str:
18 |     return bytes(name).decode("utf-8").strip()
19 |
20 |
21 | async def generate_spot_configs(drift_client: DriftClient) -> str:
22 |     spot_markets = sorted(
23 |         drift_client.get_spot_market_accounts(), key=lambda market: market.market_index
24 |     )
25 |
26 |     configs = []
27 |     for market in spot_markets:
28 |         config = SpotMarketConfig(
29 |             symbol=decode_name(market.name),
30 |             market_index=market.market_index,
31 |             oracle=market.oracle,
32 |             oracle_source=market.oracle_source,
33 |             mint=market.mint,
34 |         )
35 |         configs.append(config)
36 |
37 |     output = """
38 | mainnet_spot_market_configs: list[SpotMarketConfig] = ["""
39 |
40 |     for config in configs:
41 |         output += f"""
42 |     SpotMarketConfig(
43 |         symbol="{config.symbol}",
44 |         market_index={config.market_index},
45 |         oracle=Pubkey.from_string("{str(config.oracle)}"),
46 |         oracle_source=OracleSource.{config.oracle_source.__class__.__name__}(),  # type: ignore
47 |         mint=Pubkey.from_string("{str(config.mint)}"),
48 |     ),"""
49 |
50 |     output += "\n]\n"
51 |     return output
52 |
53 |
54 | async def generate_perp_configs(drift_client: DriftClient) -> str:
55 |     perp_markets = sorted(
56 |         drift_client.get_perp_market_accounts(), key=lambda market: market.market_index
57 |     )
58 |
59 |     configs = []
60 |     for market in perp_markets:
61 |         config = PerpMarketConfig(
62 |             symbol=decode_name(market.name),
63 |             base_asset_symbol="-".join(decode_name(market.name).split("-")[:-1]),
64 |             market_index=market.market_index,
65 |             oracle=market.amm.oracle,
66 |             oracle_source=market.amm.oracle_source,
67 |         )
68 |         configs.append(config)
69 |
70 |     # Generate Python code
71 |     output = """
72 | mainnet_perp_market_configs: list[PerpMarketConfig] = ["""
73 |
74 |     for config in configs:
75 |         output += f"""
76 |     PerpMarketConfig(
77 |         symbol="{config.symbol}",
78 |         base_asset_symbol="{config.base_asset_symbol}",
79 |         market_index={config.market_index},
80 |         oracle=Pubkey.from_string("{str(config.oracle)}"),
81 |         oracle_source=OracleSource.{config.oracle_source.__class__.__name__}(),  # type: ignore
82 |     ),"""
83 |
84 |     output += "\n]\n"
85 |
86 |     return output
87 |
88 |
89 | async def main():
90 |     """Print generated perp or spot market config code for mainnet."""
91 |     parser = argparse.ArgumentParser()
92 |     parser.add_argument("--market-type", choices=["perp", "spot"], required=True)
93 |     args = parser.parse_args()
94 |
95 |     rpc_url = os.getenv("MAINNET_RPC_ENDPOINT")
96 |     if not rpc_url:
97 |         raise ValueError("MAINNET_RPC_ENDPOINT is not set")
98 |
99 |     drift_client = DriftClient(
100 |         AsyncClient(rpc_url),
101 |         Wallet.dummy(),
102 |         env="mainnet",
103 |         account_subscription=AccountSubscriptionConfig("cached"),
104 |     )
105 |
106 |     await drift_client.subscribe()
107 |     try:
108 |         if args.market_type == "perp":
109 |             output = await generate_perp_configs(drift_client)
110 |         else:
111 |             output = await generate_spot_configs(drift_client)
112 |         print(output)
113 |     finally:
114 |         # Release the subscription and close the RPC session even on error.
115 |         await drift_client.unsubscribe()
116 |         await drift_client.connection.close()
117 |
118 |
119 | if __name__ == "__main__":
120 |     asyncio.run(main())
121 |
--------------------------------------------------------------------------------
/scripts/math.sh:
--------------------------------------------------------------------------------
1 | pytest -v -s tests/math/*.py
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | git submodule update --init --recursive
2 | cd protocol-v2
3 | yarn && anchor build
--------------------------------------------------------------------------------
/src/driftpy/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.8.56"
2 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/__init__.py:
--------------------------------------------------------------------------------
1 | from .get_accounts import *
2 | from .types import *
3 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/cache/__init__.py:
--------------------------------------------------------------------------------
1 | from .drift_client import *
2 | from .user import *
3 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/cache/user.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from anchorpy import Program
4 | from solders.pubkey import Pubkey
5 | from solana.rpc.commitment import Commitment
6 |
7 | from driftpy.accounts import get_user_account_and_slot
8 | from driftpy.accounts import UserAccountSubscriber, DataAndSlot
9 | from driftpy.types import UserAccount
10 |
11 |
12 | class CachedUserAccountSubscriber(UserAccountSubscriber):
13 |     def __init__(
14 |         self,
15 |         user_pubkey: Pubkey,
16 |         program: Program,
17 |         commitment: Commitment = "confirmed",
18 |     ):
19 |         self.program = program
20 |         self.commitment = commitment
21 |         self.user_pubkey = user_pubkey
22 |         self.user_and_slot = None
23 |
24 |     async def subscribe(self):
25 |         await self.update_cache()
26 |
27 |     async def update_cache(self):
28 |         user_and_slot = await get_user_account_and_slot(self.program, self.user_pubkey)
29 |         self.user_and_slot = user_and_slot
30 |
31 |     async def fetch(self):
32 |         await self.update_cache()
33 |
34 |     def update_data(self, data: Optional[DataAndSlot[UserAccount]]):
35 |         if data is not None:
36 |             if self.user_and_slot is None or data.slot >= self.user_and_slot.slot:
37 |                 self.user_and_slot = data
38 |
39 |     def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
40 |         return self.user_and_slot
41 |
42 |     def unsubscribe(self):
43 |         self.user_and_slot = None
44 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/demo/__init__.py:
--------------------------------------------------------------------------------
1 | from .drift_client import *
2 | from .user import *
3 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/demo/user.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from anchorpy import Program
4 | from solders.pubkey import Pubkey
5 | from solana.rpc.commitment import Commitment
6 |
7 | from driftpy.accounts import get_user_account_and_slot
8 | from driftpy.accounts import UserAccountSubscriber, DataAndSlot
9 | from driftpy.types import UserAccount
10 |
11 |
12 | class DemoUserAccountSubscriber(UserAccountSubscriber):
13 |     def __init__(
14 |         self,
15 |         user_pubkey: Pubkey,
16 |         program: Program,
17 |         commitment: Commitment = "confirmed",
18 |     ):
19 |         self.program = program
20 |         self.commitment = commitment
21 |         self.user_pubkey = user_pubkey
22 |         self.user_and_slot = None
23 |
24 |     async def subscribe(self):
25 |         await self.update_cache()
26 |
27 |     async def update_cache(self):
28 |         user_and_slot = await get_user_account_and_slot(self.program, self.user_pubkey)
29 |         self.user_and_slot = user_and_slot
30 |
31 |     async def fetch(self):
32 |         await self.update_cache()
33 |
34 |     def update_data(self, data: Optional[DataAndSlot[UserAccount]]):
35 |         if data is not None:
36 |             if self.user_and_slot is None or data.slot >= self.user_and_slot.slot:
37 |                 self.user_and_slot = data
38 |
39 |     def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
40 |         return self.user_and_slot
41 |
42 |     def unsubscribe(self):
43 |         self.user_and_slot = None
44 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/grpc/geyser_codegen/README.md:
--------------------------------------------------------------------------------
1 | # Note for driftpy:
2 |
3 | These `geyser_pb2` files were generated following
4 | the instructions at https://github.com/jito-labs/jito-python/ and committed here.
5 |
6 | Original readme below:
7 |
8 |
9 | # Python example
10 |
11 | ## DISCLAIMER
12 |
13 | This example may contain errors or lag behind the latest stable version; please use it only as an example of what your subscription can look like. If you want a well-tested, production-ready example, please check our Rust implementation.
14 |
15 |
16 |
17 | ## Instruction
18 |
19 | This Python example uses the [grpc.io](https://grpc.io/) library.
20 | It assumes the CA trust store on your machine trusts the CA of your RPC endpoint.
21 |
22 | ## Installation
23 |
24 | Create a virtual environment and install its dependencies:
25 | ```bash
26 | $ python -m venv venv
27 | $ . venv/bin/activate
28 | (venv) $ python -m pip install -U pip
29 | (venv) $ python -m pip install -r requirements.txt
30 | ```
31 |
32 | ## Launch the helloworld_geyser
33 |
34 | Print the usage:
35 |
36 | ```bash
37 | (venv) $ python helloworld_geyser.py --help
38 | Usage: helloworld_geyser.py [OPTIONS]
39 |
40 |   Simple program to get the latest solana slot number
41 |
42 | Options:
43 |   --rpc-fqdn TEXT  Fully Qualified domain name of your RPC endpoint
44 |   --x-token TEXT   x-token to authenticate each gRPC call
45 |   --help           Show this message and exit.
46 | ```
47 |
48 | - `rpc-fqdn`: the fully qualified domain name without the `https://`, such as `index.rpcpool.com`.
49 | - `x-token`: the x-token used to authenticate yourself to the RPC node.
50 |
51 | Here is a full example:
52 |
53 | ```bash
54 | (venv) $ python helloworld_geyser.py --rpc-fqdn 'index.rpcpool.com' --x-token '2625ae71-0823-41b3-b3bc-4ff89d762d52'
55 | slot: 264236514
56 |
57 | ```
58 |
59 | **NOTE**: `2625ae71-0823-41b3-b3bc-4ff89d762d52` is a fake x-token; you need to provide your own token.
60 |
61 | ## Generate gRPC service and request signatures
62 |
63 | The library `grpcio` generates the stubs for you.
64 |
65 | From the directory of `helloworld_geyser.py` you can generate all the stubs and data types using the following command:
66 |
67 | ```bash
68 | (venv) $ python -m grpc_tools.protoc -I../../yellowstone-grpc-proto/proto/ --python_out=. --pyi_out=. --grpc_python_out=. ../../yellowstone-grpc-proto/proto/*
69 | ```
70 |
71 | This will generate:
72 | - geyser_pb2.py
73 | - geyser_pb2.pyi
74 | - geyser_pb2_grpc.py
75 | - solana_storage_pb2.py
76 | - solana_storage_pb2.pyi
77 | - solana_storage_pb2_grpc.py
78 |
79 | You can then import these modules into your code.
80 |
81 |
82 | ## Useful documentation for grpcio authentication process
83 |
84 | - [secure_channel](https://grpc.github.io/grpc/python/grpc.html#create-client-credentials)
85 |
86 | - [extend auth method via call credentials](https://grpc.io/docs/guides/auth/#extending-grpc-to-support-other-authentication-mechanisms)
87 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/grpc/geyser_codegen/solana_storage_pb2_grpc.py:
--------------------------------------------------------------------------------
1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2 | """Client and server classes corresponding to protobuf-defined services."""
3 | import grpc
4 | import warnings
5 |
6 |
7 | GRPC_GENERATED_VERSION = '1.68.1'
8 | GRPC_VERSION = grpc.__version__
9 | _version_not_supported = False
10 |
11 | try:
12 |     from grpc._utilities import first_version_is_lower
13 |     _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
14 | except ImportError:
15 |     _version_not_supported = True
16 |
17 | if _version_not_supported:
18 |     raise RuntimeError(
19 |         f'The grpc package installed is at version {GRPC_VERSION},'
20 |         + f' but the generated code in solana_storage_pb2_grpc.py depends on'
21 |         + f' grpcio>={GRPC_GENERATED_VERSION}.'
22 |         + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
23 |         + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
24 |     )
25 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/grpc/geyser_codegen/subscribe_geyser.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import time
4 | from typing import Iterator, Optional
5 |
6 | import geyser_pb2
7 | import geyser_pb2_grpc
8 | import grpc
9 |
10 |
11 | def _triton_sign_request(
12 |     callback: grpc.AuthMetadataPluginCallback,
13 |     x_token: Optional[str],
14 |     error: Optional[Exception],
15 | ):
16 |     metadata = (("x-token", x_token),)
17 |     return callback(metadata, error)
18 |
19 |
20 | class TritonAuthMetadataPlugin(grpc.AuthMetadataPlugin):
21 |     def __init__(self, x_token: str):
22 |         self.x_token = x_token
23 |
24 |     def __call__(
25 |         self,
26 |         context: grpc.AuthMetadataContext,
27 |         callback: grpc.AuthMetadataPluginCallback,
28 |     ):
29 |         return _triton_sign_request(callback, self.x_token, None)
30 |
31 |
32 | def create_subscribe_request() -> Iterator[geyser_pb2.SubscribeRequest]:
33 |     request = geyser_pb2.SubscribeRequest()
34 |
35 |     # Create the account filter
36 |     account_filter = geyser_pb2.SubscribeRequestFilterAccounts()
37 |
38 |     # Add a specific account to monitor
39 |     # Note: This needs to be the base58 encoded public key
40 |     account_filter.account.append("dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH")
41 |     account_filter.account.append("DRiP8pChV3hr8FkdgPpxpVwQh3dHnTzHgdbn5Z3fmwHc")
42 |     account_filter.account.append("Fe4hMZrg7R97ZrbSScWBXUpQwZB9gzBnhodTCGyjkHsG")
43 |     account_filter.nonempty_txn_signature = True
44 |
45 |     # Copy the filter into the request
46 |     request.accounts["account_monitor"].CopyFrom(account_filter)
47 |     request.commitment = geyser_pb2.CommitmentLevel.CONFIRMED
48 |
49 |     yield request
50 |
51 |     # Keep connection alive with pings
52 |     while True:
53 |         time.sleep(30)
54 |         ping_request = geyser_pb2.SubscribeRequest()
55 |         ping_request.ping.id = int(time.time())
56 |         yield ping_request
57 |
58 |
59 | def handle_subscribe_updates(stub: geyser_pb2_grpc.GeyserStub):
60 |     """
61 |     Handles the streaming updates from the subscription.
62 |     Each update can contain different types of data based on our filters.
63 |     """
64 |     try:
65 |         request_iterator = create_subscribe_request()
66 |         print("Starting subscription stream...")
67 |
68 |         response_stream = stub.Subscribe(request_iterator)
69 |
70 |         for update in response_stream:
71 |             if update.HasField("account"):
72 |                 print("\nAccount Update:")
73 |                 print(f"  Pubkey: {update.account.account.pubkey.hex()}")
74 |                 print(f"  Owner: {update.account.account.owner.hex()}")
75 |                 print(f"  Lamports: {update.account.account.lamports}")
76 |                 print(f"  Slot: {update.account.slot}")
77 |                 if update.account.account.txn_signature:
78 |                     print(
79 |                         f"  Transaction: {update.account.account.txn_signature.hex()}"
80 |                     )
81 |
82 |             elif update.HasField("pong"):
83 |                 print(f"Received pong: {update.pong.id}")
84 |
85 |     except grpc.RpcError as e:
86 |         logging.error(f"RPC error occurred: {str(e)}")
87 |         raise
88 |
89 |
90 | def run_subscription_client():
91 |     """Open an authenticated gRPC channel from env config and stream updates."""
92 |     rpc_fqdn = os.environ.get("RPC_FDQN")
93 |     x_token = os.environ.get("X_TOKEN")
94 |     if not rpc_fqdn or not x_token:
95 |         # Fail fast with a clear message: otherwise a None target or a None
96 |         # x-token metadata value only surfaces later as an opaque gRPC error.
97 |         raise ValueError("RPC_FDQN and X_TOKEN environment variables must be set")
98 |
99 |     auth = TritonAuthMetadataPlugin(x_token)
100 |     ssl_creds = grpc.ssl_channel_credentials()
101 |     call_creds = grpc.metadata_call_credentials(auth)
102 |     combined_creds = grpc.composite_channel_credentials(ssl_creds, call_creds)
103 |
104 |     with grpc.secure_channel(rpc_fqdn, credentials=combined_creds) as channel:
105 |         stub = geyser_pb2_grpc.GeyserStub(channel)
106 |         handle_subscribe_updates(stub)
107 |
108 |
109 | if __name__ == "__main__":
110 |     logging.basicConfig(level=logging.INFO)
111 |     run_subscription_client()
112 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/grpc/user.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from driftpy.accounts.grpc.account_subscriber import GrpcAccountSubscriber
4 | from driftpy.accounts.types import DataAndSlot, UserAccountSubscriber
5 | from driftpy.types import UserAccount
6 |
7 |
8 | class GrpcUserAccountSubscriber(
9 |     GrpcAccountSubscriber[UserAccount], UserAccountSubscriber
10 | ):
11 |     def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
12 |         return self.data_and_slot
13 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/grpc/user_stats.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from driftpy.accounts.grpc.account_subscriber import GrpcAccountSubscriber
4 | from driftpy.accounts.types import DataAndSlot, UserStatsAccountSubscriber
5 | from driftpy.types import UserStatsAccount
6 |
7 |
8 | class GrpcUserStatsAccountSubscriber(
9 |     GrpcAccountSubscriber[UserStatsAccount], UserStatsAccountSubscriber
10 | ):
11 |     def get_user_stats_account_and_slot(
12 |         self,
13 |     ) -> Optional[DataAndSlot[UserStatsAccount]]:
14 |         return self.data_and_slot
15 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/polling/__init__.py:
--------------------------------------------------------------------------------
1 | from .drift_client import *
2 | from .user import *
3 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/polling/user.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from anchorpy import Program
4 | from solders.pubkey import Pubkey
5 |
6 | from driftpy.accounts import (
7 |     UserAccountSubscriber,
8 |     DataAndSlot,
9 |     get_user_account_and_slot,
10 | )
11 |
12 | from driftpy.accounts.bulk_account_loader import BulkAccountLoader
13 | from driftpy.types import UserAccount
14 |
15 |
16 | class
PollingUserAccountSubscriber(UserAccountSubscriber): 17 | def __init__( 18 | self, 19 | user_account_pubkey: Pubkey, 20 | program: Program, 21 | bulk_account_loader: BulkAccountLoader, 22 | ): 23 | self.bulk_account_loader = bulk_account_loader 24 | self.program = program 25 | self.user_account_pubkey = user_account_pubkey 26 | self.data_and_slot: Optional[DataAndSlot[UserAccount]] = None 27 | self.decode = self.program.coder.accounts.decode 28 | self.callback_id = None 29 | 30 | async def subscribe(self): 31 | if self.callback_id is not None: 32 | return 33 | 34 | self.add_to_account_loader() 35 | 36 | if self.data_and_slot is None: 37 | await self.fetch() 38 | 39 | def add_to_account_loader(self): 40 | if self.callback_id is not None: 41 | return 42 | 43 | self.callback_id = self.bulk_account_loader.add_account( 44 | self.user_account_pubkey, self._account_loader_callback 45 | ) 46 | 47 | def _account_loader_callback(self, buffer: bytes, slot: int): 48 | if buffer is None: 49 | return 50 | 51 | if self.data_and_slot is not None and self.data_and_slot.slot >= slot: 52 | return 53 | 54 | account = self.decode(buffer) 55 | self.data_and_slot = DataAndSlot(slot, account) 56 | 57 | async def fetch(self): 58 | data_and_slot = await get_user_account_and_slot( 59 | self.program, self.user_account_pubkey 60 | ) 61 | self.update_data(data_and_slot) 62 | 63 | def update_data(self, new_data: Optional[DataAndSlot[UserAccount]]): 64 | if new_data is None: 65 | return 66 | 67 | if self.data_and_slot is None or new_data.slot >= self.data_and_slot.slot: 68 | self.data_and_slot = new_data 69 | 70 | def unsubscribe(self): 71 | if self.callback_id is None: 72 | return 73 | 74 | self.bulk_account_loader.remove_account( 75 | self.user_account_pubkey, self.callback_id 76 | ) 77 | 78 | self.callback_id = None 79 | 80 | def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]: 81 | return self.data_and_slot 82 | 
-------------------------------------------------------------------------------- /src/driftpy/accounts/types.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Awaitable, Callable, Generic, Optional, Sequence, TypeVar, Union 4 | 5 | from solana.rpc.commitment import Commitment 6 | from solana.rpc.types import MemcmpOpts 7 | from solders.pubkey import Pubkey 8 | 9 | from driftpy.types import ( 10 | OraclePriceData, 11 | OracleSource, 12 | PerpMarketAccount, 13 | SpotMarketAccount, 14 | StateAccount, 15 | UserAccount, 16 | UserStatsAccount, 17 | ) 18 | 19 | T = TypeVar("T") 20 | 21 | 22 | @dataclass 23 | class DataAndSlot(Generic[T]): 24 | slot: int 25 | data: T 26 | 27 | 28 | @dataclass 29 | class FullOracleWrapper: 30 | pubkey: Pubkey 31 | oracle_source: OracleSource 32 | oracle_price_data_and_slot: Optional[DataAndSlot[OraclePriceData]] 33 | 34 | 35 | @dataclass 36 | class BufferAndSlot: 37 | slot: int 38 | buffer: bytes 39 | 40 | 41 | @dataclass 42 | class WebsocketProgramAccountOptions: 43 | filters: Sequence[MemcmpOpts] 44 | commitment: Commitment 45 | encoding: str 46 | 47 | 48 | @dataclass 49 | class GrpcProgramAccountOptions: 50 | filters: Sequence[MemcmpOpts] 51 | commitment: Commitment 52 | 53 | 54 | UpdateCallback = Callable[[str, DataAndSlot[UserAccount]], Awaitable[None]] 55 | 56 | MarketUpdateCallback = Callable[ 57 | [str, DataAndSlot[Union[PerpMarketAccount, SpotMarketAccount]]], Awaitable[None] 58 | ] 59 | 60 | 61 | class DriftClientAccountSubscriber: 62 | @abstractmethod 63 | async def subscribe(self): 64 | pass 65 | 66 | @abstractmethod 67 | def unsubscribe(self): 68 | pass 69 | 70 | @abstractmethod 71 | async def fetch(self): 72 | pass 73 | 74 | @abstractmethod 75 | def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]: 76 | pass 77 | 78 | @abstractmethod 79 | def get_perp_market_and_slot( 80 | 
self, market_index: int 81 | ) -> Optional[DataAndSlot[PerpMarketAccount]]: 82 | pass 83 | 84 | @abstractmethod 85 | def get_spot_market_and_slot( 86 | self, market_index: int 87 | ) -> Optional[DataAndSlot[SpotMarketAccount]]: 88 | pass 89 | 90 | @abstractmethod 91 | def get_oracle_price_data_and_slot( 92 | self, oracle: Pubkey 93 | ) -> Optional[DataAndSlot[OraclePriceData]]: 94 | pass 95 | 96 | @abstractmethod 97 | def get_market_accounts_and_slots(self) -> list[DataAndSlot[PerpMarketAccount]]: 98 | pass 99 | 100 | @abstractmethod 101 | def get_spot_market_accounts_and_slots( 102 | self, 103 | ) -> list[DataAndSlot[SpotMarketAccount]]: 104 | pass 105 | 106 | 107 | class UserAccountSubscriber: 108 | @abstractmethod 109 | async def subscribe(self): 110 | pass 111 | 112 | @abstractmethod 113 | def unsubscribe(self): 114 | pass 115 | 116 | @abstractmethod 117 | async def fetch(self): 118 | pass 119 | 120 | async def update_data(self, data: Optional[DataAndSlot[UserAccount]]): 121 | pass 122 | 123 | @abstractmethod 124 | def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]: 125 | pass 126 | 127 | 128 | class UserStatsAccountSubscriber: 129 | @abstractmethod 130 | async def subscribe(self): 131 | pass 132 | 133 | @abstractmethod 134 | def unsubscribe(self): 135 | pass 136 | 137 | @abstractmethod 138 | async def fetch(self): 139 | pass 140 | 141 | @abstractmethod 142 | def get_user_stats_account_and_slot( 143 | self, 144 | ) -> Optional[DataAndSlot[UserStatsAccount]]: 145 | pass 146 | -------------------------------------------------------------------------------- /src/driftpy/accounts/ws/__init__.py: -------------------------------------------------------------------------------- 1 | from .drift_client import * 2 | from .user import * 3 | from .user_stats import * 4 | -------------------------------------------------------------------------------- /src/driftpy/accounts/ws/account_subscriber.py: 
-------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Callable, Generic, Optional, TypeVar, cast 3 | 4 | import websockets 5 | import websockets.exceptions # force eager imports 6 | from anchorpy.program.core import Program 7 | from solana.rpc.commitment import Commitment 8 | from solana.rpc.websocket_api import SolanaWsClientProtocol, connect 9 | from solders.pubkey import Pubkey 10 | 11 | from driftpy.accounts import ( 12 | DataAndSlot, 13 | UserAccountSubscriber, 14 | UserStatsAccountSubscriber, 15 | get_account_data_and_slot, 16 | ) 17 | from driftpy.types import get_ws_url 18 | 19 | T = TypeVar("T") 20 | 21 | 22 | class WebsocketAccountSubscriber( 23 | UserAccountSubscriber, UserStatsAccountSubscriber, Generic[T] 24 | ): 25 | def __init__( 26 | self, 27 | pubkey: Pubkey, 28 | program: Program, 29 | commitment: Commitment = Commitment("confirmed"), 30 | decode: Optional[Callable[[bytes], T]] = None, 31 | initial_data: Optional[DataAndSlot] = None, 32 | ): 33 | self.program = program 34 | self.commitment = commitment 35 | self.pubkey = pubkey 36 | self.data_and_slot = initial_data or None 37 | self.task = None 38 | self.decode = ( 39 | decode if decode is not None else self.program.coder.accounts.decode 40 | ) 41 | self.ws: Optional[SolanaWsClientProtocol] = None 42 | 43 | async def subscribe(self): 44 | if self.data_and_slot is None: 45 | await self.fetch() 46 | 47 | self.task = asyncio.create_task(self.subscribe_ws()) 48 | return self.task 49 | 50 | def is_subscribed(self): 51 | return self.task is not None 52 | 53 | async def subscribe_ws(self): 54 | endpoint = self.program.provider.connection._provider.endpoint_uri 55 | ws_endpoint = get_ws_url(endpoint) 56 | 57 | async for ws in connect(ws_endpoint): 58 | try: 59 | self.ws = cast(SolanaWsClientProtocol, ws) 60 | await self.ws.account_subscribe( 61 | self.pubkey, 62 | commitment=self.commitment, 63 | encoding="base64", 64 | ) 65 | 
first_resp = await ws.recv() 66 | subscription_id = cast(int, first_resp[0].result) 67 | 68 | async for msg in ws: 69 | try: 70 | slot = int(msg[0].result.context.slot) # type: ignore 71 | 72 | if msg[0].result.value is None: 73 | continue 74 | 75 | account_bytes = cast(bytes, msg[0].result.value.data) # type: ignore 76 | decoded_data = self.decode(account_bytes) 77 | self.update_data(DataAndSlot(slot, decoded_data)) 78 | except Exception: 79 | print("Error processing account data") 80 | break 81 | await self.ws.account_unsubscribe(subscription_id) 82 | except websockets.exceptions.ConnectionClosed: 83 | print("Websocket closed, reconnecting...") 84 | continue 85 | 86 | async def fetch(self): 87 | new_data = await get_account_data_and_slot( 88 | self.pubkey, self.program, self.commitment, self.decode 89 | ) 90 | self.update_data(new_data) 91 | 92 | def update_data(self, new_data: Optional[DataAndSlot[T]]): 93 | if new_data is None: 94 | return 95 | 96 | if self.data_and_slot is None or new_data.slot >= self.data_and_slot.slot: 97 | self.data_and_slot = new_data 98 | 99 | async def unsubscribe(self): 100 | if self.task: 101 | self.task.cancel() 102 | self.task = None 103 | if self.ws: 104 | await self.ws.close() 105 | self.ws = None 106 | -------------------------------------------------------------------------------- /src/driftpy/accounts/ws/user.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from driftpy.accounts import DataAndSlot 4 | from driftpy.types import UserAccount 5 | 6 | from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber 7 | from driftpy.accounts.types import UserAccountSubscriber 8 | 9 | 10 | class WebsocketUserAccountSubscriber( 11 | WebsocketAccountSubscriber[UserAccount], UserAccountSubscriber 12 | ): 13 | def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]: 14 | return self.data_and_slot 15 | 
-------------------------------------------------------------------------------- /src/driftpy/accounts/ws/user_stats.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from driftpy.accounts import DataAndSlot 4 | from driftpy.accounts.types import UserStatsAccountSubscriber 5 | from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber 6 | from driftpy.types import UserStatsAccount 7 | 8 | 9 | class WebsocketUserStatsAccountSubscriber( 10 | WebsocketAccountSubscriber[UserStatsAccount], UserStatsAccountSubscriber 11 | ): 12 | def get_user_stats_account_and_slot( 13 | self, 14 | ) -> Optional[DataAndSlot[UserStatsAccount]]: 15 | return self.data_and_slot 16 | -------------------------------------------------------------------------------- /src/driftpy/address_lookup_table.py: -------------------------------------------------------------------------------- 1 | from solana.rpc.async_api import AsyncClient 2 | from solders.pubkey import Pubkey 3 | from solders.address_lookup_table_account import AddressLookupTableAccount 4 | 5 | LOOKUP_TABLE_META_SIZE = 56 6 | 7 | 8 | async def get_address_lookup_table( 9 | connection: AsyncClient, pubkey: Pubkey 10 | ) -> AddressLookupTableAccount: 11 | account_info = await connection.get_account_info(pubkey) 12 | return decode_address_lookup_table(pubkey, account_info.value.data) 13 | 14 | 15 | def decode_address_lookup_table(pubkey: Pubkey, data: bytes): 16 | data_len = len(data) 17 | 18 | addresses = [] 19 | i = LOOKUP_TABLE_META_SIZE 20 | while i < data_len: 21 | addresses.append(Pubkey.from_bytes(data[i : i + 32])) 22 | i += 32 23 | 24 | return AddressLookupTableAccount(pubkey, addresses) 25 | -------------------------------------------------------------------------------- /src/driftpy/auction_subscriber/auction_subscriber.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from 
typing import Optional 3 | 4 | from events.events import Events, _EventSlot 5 | from solders.pubkey import Pubkey 6 | 7 | from driftpy.accounts.types import DataAndSlot, WebsocketProgramAccountOptions 8 | from driftpy.accounts.ws.program_account_subscriber import ( 9 | WebSocketProgramAccountSubscriber, 10 | ) 11 | from driftpy.auction_subscriber.types import AuctionSubscriberConfig 12 | from driftpy.decode.user import decode_user 13 | from driftpy.memcmp import get_user_filter, get_user_with_auction_filter 14 | from driftpy.types import UserAccount 15 | 16 | 17 | class AuctionEvents(Events): 18 | """ 19 | Subclass to explicitly declare the on_account_update event 20 | both in __events__ and as a typed attribute. 21 | """ 22 | 23 | __events__ = ("on_account_update",) 24 | on_account_update: _EventSlot 25 | 26 | 27 | class AuctionSubscriber: 28 | def __init__(self, config: AuctionSubscriberConfig): 29 | self.drift_client = config.drift_client 30 | self.commitment = ( 31 | config.commitment 32 | if config.commitment is not None 33 | else self.drift_client.connection.commitment 34 | ) 35 | self.resub_timeout_ms = config.resub_timeout_ms 36 | self.subscriber: Optional[WebSocketProgramAccountSubscriber] = None 37 | self.event_emitter = AuctionEvents() 38 | 39 | async def on_update(self, account_pubkey: str, data: DataAndSlot[UserAccount]): 40 | self.event_emitter.on_account_update( 41 | data.data, 42 | Pubkey.from_string(account_pubkey), 43 | data.slot, # type: ignore 44 | ) 45 | 46 | async def subscribe(self): 47 | if self.subscriber is None: 48 | filters = (get_user_filter(), get_user_with_auction_filter()) 49 | options = WebsocketProgramAccountOptions(filters, self.commitment, "base64") 50 | self.subscriber = WebSocketProgramAccountSubscriber( 51 | "AuctionSubscriber", 52 | self.drift_client.program, 53 | options, 54 | self.on_update, 55 | decode_user, 56 | self.resub_timeout_ms, 57 | ) 58 | 59 | if self.subscriber.subscribed: 60 | return 61 | 62 | await 
self.subscriber.subscribe() 63 | 64 | def unsubscribe(self): 65 | if self.subscriber is None: 66 | return 67 | asyncio.create_task(self.subscriber.unsubscribe()) 68 | -------------------------------------------------------------------------------- /src/driftpy/auction_subscriber/grpc_auction_subscriber.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from events.events import Events, _EventSlot 4 | from solders.pubkey import Pubkey 5 | 6 | from driftpy.accounts.grpc.account_subscriber import GrpcAccountSubscriber 7 | from driftpy.accounts.grpc.program_account_subscriber import ( 8 | GrpcProgramAccountSubscriber, 9 | ) 10 | from driftpy.accounts.types import DataAndSlot, GrpcProgramAccountOptions 11 | from driftpy.auction_subscriber.types import GrpcAuctionSubscriberConfig 12 | from driftpy.decode.user import decode_user 13 | from driftpy.memcmp import get_user_filter, get_user_with_auction_filter 14 | from driftpy.types import UserAccount 15 | 16 | 17 | class GrpcAuctionEvents(Events): 18 | __events__ = ("on_account_update",) 19 | on_account_update: _EventSlot 20 | 21 | 22 | class GrpcAuctionSubscriber: 23 | def __init__(self, config: GrpcAuctionSubscriberConfig): 24 | self.config = config 25 | self.drift_client = config.drift_client 26 | self.commitment = ( 27 | config.commitment 28 | if config.commitment is not None 29 | else self.drift_client.connection.commitment 30 | ) 31 | self.subscribers: list[GrpcAccountSubscriber] = [] 32 | self.event_emitter = GrpcAuctionEvents() 33 | 34 | async def on_update(self, account_pubkey: str, data: DataAndSlot[UserAccount]): 35 | self.event_emitter.on_account_update( 36 | data.data, Pubkey.from_string(account_pubkey), data.slot 37 | ) 38 | 39 | async def subscribe(self): 40 | if self.subscribers: 41 | return 42 | 43 | filters = (get_user_filter(), get_user_with_auction_filter()) 44 | options = GrpcProgramAccountOptions(filters, self.commitment) 45 | self.subscriber 
= GrpcProgramAccountSubscriber( 46 | grpc_config=self.config.grpc_config, 47 | subscription_name="AuctionSubscriber", 48 | program=self.drift_client.program, 49 | options=options, 50 | on_update=self.on_update, 51 | decode=decode_user, 52 | ) 53 | await self.subscriber.subscribe() 54 | 55 | def unsubscribe(self): 56 | if not self.subscribers: 57 | return 58 | 59 | for subscriber in self.subscribers: 60 | asyncio.create_task(subscriber.unsubscribe()) 61 | self.subscribers.clear() 62 | -------------------------------------------------------------------------------- /src/driftpy/auction_subscriber/types.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from solana.rpc.commitment import Commitment 4 | 5 | from driftpy.drift_client import DriftClient 6 | from driftpy.types import GrpcConfig 7 | 8 | 9 | class AuctionSubscriberConfig: 10 | def __init__( 11 | self, 12 | drift_client: DriftClient, 13 | commitment: Optional[Commitment] = None, 14 | resub_timeout_ms: Optional[int] = None, 15 | ): 16 | self.drift_client = drift_client 17 | self.commitment = commitment 18 | self.resub_timeout_ms = resub_timeout_ms 19 | 20 | 21 | class GrpcAuctionSubscriberConfig(AuctionSubscriberConfig): 22 | def __init__( 23 | self, 24 | drift_client: DriftClient, 25 | grpc_config: GrpcConfig, 26 | commitment: Optional[Commitment] = None, 27 | ): 28 | super().__init__(drift_client, commitment) 29 | self.grpc_config = grpc_config 30 | -------------------------------------------------------------------------------- /src/driftpy/constants/__init__.py: -------------------------------------------------------------------------------- 1 | from driftpy.constants.numeric_constants import * 2 | -------------------------------------------------------------------------------- /src/driftpy/decode/pull_oracle.py: -------------------------------------------------------------------------------- 1 | from driftpy.decode.user import ( 2 | 
read_uint8, 3 | read_int32_le, 4 | read_bigint64le, 5 | ) 6 | from driftpy.types import PriceUpdateV2, PriceFeedMessage 7 | 8 | 9 | def decode_pull_oracle(buffer: bytes) -> PriceUpdateV2: 10 | offset = 8 11 | 12 | offset += 32 # skip 0 13 | 14 | verification_level_flag = read_uint8(buffer, offset) 15 | if verification_level_flag & 0x1: 16 | offset += 1 # skip verification_level Full 17 | else: 18 | offset += 2 # skip verificaton_level Partial 19 | 20 | offset += 32 # skip feed_id 21 | 22 | price = read_bigint64le(buffer, offset, True) 23 | offset += 8 24 | 25 | conf = read_bigint64le(buffer, offset, False) 26 | offset += 8 27 | 28 | exponent = read_int32_le(buffer, offset, True) 29 | offset += 4 30 | 31 | offset += 8 # skip publish_time 32 | offset += 8 # skip prev_publish_time 33 | 34 | ema_price = read_bigint64le(buffer, offset, True) 35 | offset += 8 36 | 37 | ema_conf = read_bigint64le(buffer, offset, False) 38 | offset += 8 39 | 40 | posted_slot = read_bigint64le(buffer, offset, False) 41 | 42 | price_feed_message = PriceFeedMessage(price, conf, exponent, ema_price, ema_conf) 43 | 44 | return PriceUpdateV2(price_feed_message, posted_slot) 45 | -------------------------------------------------------------------------------- /src/driftpy/decode/user_stat.py: -------------------------------------------------------------------------------- 1 | from solders.pubkey import Pubkey 2 | 3 | from driftpy.decode.user import ( 4 | read_bigint64le, 5 | read_int32_le, 6 | read_uint8, 7 | read_uint16_le, 8 | ) 9 | from driftpy.types import UserFees, UserStatsAccount 10 | 11 | 12 | def decode_user_stat(buffer: bytes) -> UserStatsAccount: 13 | offset = 8 14 | authority = Pubkey(buffer[offset : offset + 32]) 15 | offset += 32 16 | referrer = Pubkey(buffer[offset : offset + 32]) 17 | offset += 32 18 | 19 | total_fee_paid = read_bigint64le(buffer, offset, False) 20 | offset += 8 21 | total_fee_rebate = read_bigint64le(buffer, offset, False) 22 | offset += 8 23 | 
total_token_discount = read_bigint64le(buffer, offset, False) 24 | offset += 8 25 | total_referee_discount = read_bigint64le(buffer, offset, False) 26 | offset += 8 27 | total_referrer_reward = read_bigint64le(buffer, offset, False) 28 | offset += 8 29 | current_epoch_referrer_reward = read_bigint64le(buffer, offset, False) 30 | offset += 8 31 | 32 | user_fees = UserFees( 33 | total_fee_paid, 34 | total_fee_rebate, 35 | total_token_discount, 36 | total_referee_discount, 37 | total_referrer_reward, 38 | current_epoch_referrer_reward, 39 | ) 40 | 41 | next_epoch_ts = read_bigint64le(buffer, offset, True) 42 | offset += 8 43 | 44 | maker_volume_30d = read_bigint64le(buffer, offset, False) 45 | offset += 8 46 | 47 | taker_volume_30d = read_bigint64le(buffer, offset, False) 48 | offset += 8 49 | 50 | filler_volume_30d = read_bigint64le(buffer, offset, False) 51 | offset += 8 52 | 53 | last_maker_volume_30d_ts = read_bigint64le(buffer, offset, True) 54 | offset += 8 55 | 56 | last_taker_volume_30d_ts = read_bigint64le(buffer, offset, True) 57 | offset += 8 58 | 59 | last_filler_volume_30d_ts = read_bigint64le(buffer, offset, True) 60 | offset += 8 61 | 62 | if_staked_quote_asset_amount = read_bigint64le(buffer, offset, False) 63 | offset += 8 64 | 65 | number_of_sub_accounts = read_uint16_le(buffer, offset) 66 | offset += 2 67 | 68 | number_of_sub_accounts_created = read_uint16_le(buffer, offset) 69 | offset += 2 70 | 71 | referrer_status = read_uint8(buffer, offset) 72 | is_referrer = (referrer_status & 0x1) == 1 73 | offset += 1 74 | 75 | disable_update_perp_bid_ask_twap = read_uint8(buffer, offset) == 1 76 | offset += 1 77 | 78 | offset += 1 79 | 80 | fuel_overflow_status = read_uint8(buffer, offset) 81 | offset += 1 82 | 83 | fuel_insurance = read_int32_le(buffer, offset, False) 84 | offset += 4 85 | 86 | fuel_deposits = read_int32_le(buffer, offset, False) 87 | offset += 4 88 | 89 | fuel_borrows = read_int32_le(buffer, offset, False) 90 | offset += 4 91 | 92 | 
fuel_positions = read_int32_le(buffer, offset, False) 93 | offset += 4 94 | 95 | fuel_taker = read_int32_le(buffer, offset, False) 96 | offset += 4 97 | 98 | fuel_maker = read_int32_le(buffer, offset, False) 99 | offset += 4 100 | 101 | if_staked_gov_token_amount = read_bigint64le(buffer, offset, False) 102 | offset += 8 103 | 104 | last_fuel_if_bonus_update_ts = read_int32_le(buffer, offset, False) 105 | offset += 4 106 | 107 | padding = [0] * 12 108 | 109 | return UserStatsAccount( 110 | authority, 111 | referrer, 112 | user_fees, 113 | next_epoch_ts, 114 | maker_volume_30d, 115 | taker_volume_30d, 116 | filler_volume_30d, 117 | last_maker_volume_30d_ts, 118 | last_taker_volume_30d_ts, 119 | last_filler_volume_30d_ts, 120 | if_staked_quote_asset_amount, 121 | number_of_sub_accounts, 122 | number_of_sub_accounts_created, 123 | is_referrer, 124 | disable_update_perp_bid_ask_twap, 125 | fuel_overflow_status, 126 | fuel_insurance, 127 | fuel_deposits, 128 | fuel_borrows, 129 | fuel_positions, 130 | fuel_taker, 131 | fuel_maker, 132 | if_staked_gov_token_amount, 133 | last_fuel_if_bonus_update_ts, 134 | padding, 135 | ) 136 | -------------------------------------------------------------------------------- /src/driftpy/decode/utils.py: -------------------------------------------------------------------------------- 1 | def decode_name(bytes_list: list[int]): 2 | byte_array = bytes(bytes_list) 3 | return byte_array.decode("utf-8").strip() 4 | -------------------------------------------------------------------------------- /src/driftpy/dlob/client_types.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | 4 | from driftpy.dlob.dlob import DLOB 5 | from driftpy.drift_client import DriftClient 6 | 7 | 8 | class DLOBSource(ABC): 9 | @abstractmethod 10 | async def get_DLOB(self, slot: int) -> DLOB: 11 | pass 12 | 13 | 14 | class SlotSource(ABC): 15 | 
@abstractmethod 16 | def get_slot(self) -> int: 17 | pass 18 | 19 | 20 | @dataclass 21 | class DLOBClientConfig: 22 | drift_client: DriftClient 23 | dlob_source: DLOBSource 24 | slot_source: SlotSource 25 | update_frequency: int 26 | -------------------------------------------------------------------------------- /src/driftpy/dlob/dlob_helpers.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | from driftpy.types import ( 3 | MarketType, 4 | PerpMarketAccount, 5 | SpotMarketAccount, 6 | StateAccount, 7 | is_variant, 8 | ) 9 | 10 | 11 | def get_maker_rebate( 12 | market_type: MarketType, 13 | state_account: StateAccount, 14 | market_account: Union[PerpMarketAccount, SpotMarketAccount], 15 | ): 16 | if is_variant(market_type, "Perp"): 17 | maker_rebate_numerator = state_account.perp_fee_structure.fee_tiers[ 18 | 0 19 | ].maker_rebate_numerator 20 | maker_rebate_denominator = state_account.perp_fee_structure.fee_tiers[ 21 | 0 22 | ].maker_rebate_denominator 23 | else: 24 | maker_rebate_numerator = state_account.spot_fee_structure.fee_tiers[ 25 | 0 26 | ].maker_rebate_numerator 27 | maker_rebate_denominator = state_account.spot_fee_structure.fee_tiers[ 28 | 0 29 | ].maker_rebate_denominator 30 | 31 | fee_adjustment = ( 32 | market_account.fee_adjustment 33 | if market_account.fee_adjustment is not None 34 | else 0 35 | ) 36 | if fee_adjustment != 0: 37 | maker_rebate_numerator += (maker_rebate_numerator * fee_adjustment) // 100 38 | 39 | return maker_rebate_numerator, maker_rebate_denominator 40 | 41 | 42 | def get_node_lists(order_lists): 43 | from driftpy.dlob.dlob_node import MarketNodeLists 44 | 45 | order_lists: Dict[str, Dict[int, MarketNodeLists]] 46 | 47 | for _, node_lists in order_lists.get("perp", {}).items(): 48 | yield node_lists.resting_limit["bid"] 49 | yield node_lists.resting_limit["ask"] 50 | yield node_lists.taking_limit["bid"] 51 | yield node_lists.taking_limit["ask"] 52 | yield 
node_lists.market["bid"] 53 | yield node_lists.market["ask"] 54 | yield node_lists.floating_limit["bid"] 55 | yield node_lists.floating_limit["ask"] 56 | yield node_lists.trigger["above"] 57 | yield node_lists.trigger["below"] 58 | 59 | for _, node_lists in order_lists.get("spot", {}).items(): 60 | yield node_lists.resting_limit["bid"] 61 | yield node_lists.resting_limit["ask"] 62 | yield node_lists.taking_limit["bid"] 63 | yield node_lists.taking_limit["ask"] 64 | yield node_lists.market["bid"] 65 | yield node_lists.market["ask"] 66 | yield node_lists.floating_limit["bid"] 67 | yield node_lists.floating_limit["ask"] 68 | yield node_lists.trigger["above"] 69 | yield node_lists.trigger["below"] 70 | -------------------------------------------------------------------------------- /src/driftpy/drift_user_stats.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | from solders.pubkey import Pubkey 5 | from solana.rpc.commitment import Commitment, Confirmed 6 | 7 | from driftpy.accounts.types import DataAndSlot 8 | from driftpy.types import ReferrerInfo, UserStatsAccount 9 | from driftpy.accounts.ws.user_stats import WebsocketUserStatsAccountSubscriber 10 | from driftpy.addresses import ( 11 | get_user_account_public_key, 12 | get_user_stats_account_public_key, 13 | ) 14 | 15 | 16 | @dataclass 17 | class UserStatsSubscriptionConfig: 18 | commitment: Commitment = Confirmed 19 | resub_timeout_ms: Optional[int] = None 20 | initial_data: Optional[DataAndSlot[UserStatsAccount]] = None 21 | 22 | 23 | class DriftUserStats: 24 | def __init__( 25 | self, 26 | drift_client, 27 | user_stats_account_pubkey: Pubkey, 28 | config: UserStatsSubscriptionConfig, 29 | ): 30 | self.drift_client = drift_client 31 | self.user_stats_account_pubkey = user_stats_account_pubkey 32 | self.account_subscriber = WebsocketUserStatsAccountSubscriber( 33 | user_stats_account_pubkey, 34 | 
drift_client.program, 35 | config.commitment, 36 | initial_data=config.initial_data, 37 | ) 38 | self.subscribed = False 39 | 40 | async def subscribe(self) -> bool: 41 | if self.subscribed: 42 | return 43 | 44 | await self.account_subscriber.subscribe() 45 | self.subscribed = True 46 | 47 | return self.subscribed 48 | 49 | async def fetch_accounts(self): 50 | await self.account_subscriber.fetch() 51 | 52 | def unsubscribe(self): 53 | self.account_subscriber.unsubscribe() 54 | 55 | def get_account_and_slot(self) -> DataAndSlot[UserStatsAccount]: 56 | return self.account_subscriber.get_user_stats_account_and_slot() 57 | 58 | def get_account(self) -> UserStatsAccount: 59 | return self.account_subscriber.get_user_stats_account_and_slot().data 60 | 61 | def get_referrer_info(self) -> Optional[ReferrerInfo]: 62 | if self.get_account().referrer == Pubkey.default(): 63 | return None 64 | else: 65 | return ReferrerInfo( 66 | get_user_account_public_key( 67 | self.drift_client.program_id, self.get_account().referrer, 0 68 | ), 69 | get_user_stats_account_public_key( 70 | self.drift_client.program_id, self.get_account().referrer 71 | ), 72 | ) 73 | -------------------------------------------------------------------------------- /src/driftpy/events/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/src/driftpy/events/__init__.py -------------------------------------------------------------------------------- /src/driftpy/events/event_list.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from dataclasses import dataclass 3 | 4 | from driftpy.events.types import WrappedEvent, EventSubscriptionOrderDirection, SortFn 5 | 6 | 7 | @dataclass 8 | class Node: 9 | event: WrappedEvent 10 | next: Optional[any] = None 11 | prev: Optional[any] = None 12 | 13 | 14 | class EventList: 
15 | def __init__( 16 | self, 17 | max_size: int, 18 | sort_fn: SortFn, 19 | order_direction: EventSubscriptionOrderDirection, 20 | ): 21 | self.size = 0 22 | self.max_size = max_size 23 | self.sort_fn = sort_fn 24 | self.order_direction = order_direction 25 | self.head = None 26 | self.tail = None 27 | 28 | def insert(self, event: WrappedEvent) -> None: 29 | self.size += 1 30 | new_node = Node(event) 31 | if self.head is None: 32 | self.head = self.tail = new_node 33 | return 34 | 35 | halt_condition = -1 if self.order_direction == "asc" else 1 36 | 37 | if self.sort_fn(self.head.event, new_node.event) == halt_condition: 38 | self.head.prev = new_node 39 | new_node.next = self.head 40 | self.head = new_node 41 | else: 42 | current_node = self.head 43 | while ( 44 | current_node.next is not None 45 | and self.sort_fn(current_node.next.event, new_node.event) 46 | != halt_condition 47 | ): 48 | current_node = current_node.next 49 | 50 | new_node.next = current_node.next 51 | if current_node.next is not None: 52 | new_node.next.prev = new_node 53 | else: 54 | self.tail = new_node 55 | 56 | current_node.next = new_node 57 | new_node.prev = current_node 58 | 59 | if self.size > self.max_size: 60 | self.detach() 61 | 62 | def detach(self) -> None: 63 | node = self.tail 64 | if node.prev is not None: 65 | node.prev.next = node.next 66 | else: 67 | self.head = node.next 68 | 69 | if node.next is not None: 70 | node.next.prev = node.prev 71 | else: 72 | self.tail = node.prev 73 | 74 | self.size -= 1 75 | 76 | def to_array(self) -> list[WrappedEvent]: 77 | return list(self) 78 | 79 | def __iter__(self): 80 | node = self.head 81 | while node: 82 | yield node.event 83 | node = node.next 84 | -------------------------------------------------------------------------------- /src/driftpy/events/event_subscriber.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from anchorpy import EventParser, Program 4 | 
from events import Events as EventEmitter 5 | from solana.rpc.async_api import AsyncClient 6 | from solders.signature import Signature 7 | 8 | from driftpy.events.event_list import EventList 9 | from driftpy.events.parse import parse_logs 10 | from driftpy.events.sort import get_sort_fn 11 | from driftpy.events.tx_event_cache import TxEventCache 12 | from driftpy.events.types import EventSubscriptionOptions, EventType, WrappedEvent 13 | 14 | 15 | class EventSubscriber: 16 | def __init__( 17 | self, 18 | connection: AsyncClient, 19 | program: Program, 20 | options: EventSubscriptionOptions = EventSubscriptionOptions.default(), 21 | ): 22 | self.connection = connection 23 | self.program = program 24 | self.options = options 25 | self.subscribed = False 26 | self.event_list_map: dict[EventType:EventList] = {} 27 | for event_type in self.options.event_types: 28 | self.event_list_map[event_type] = EventList( 29 | self.options.max_events_per_type, 30 | get_sort_fn(self.options.order_by, self.options.order_dir), 31 | self.options.order_dir, 32 | ) 33 | self.event_parser = EventParser(self.program.program_id, self.program.coder) 34 | self.log_provider = self.options.get_log_provider(connection) 35 | self.tx_event_cache = TxEventCache(self.options.max_tx) 36 | self.event_emitter = EventEmitter(("new_event",)) 37 | 38 | def subscribe(self): 39 | self.log_provider.subscribe(self.handle_tx_logs) 40 | self.subscribed = True 41 | 42 | def unsubscribe(self): 43 | self.log_provider.unsubscribe() 44 | self.subscribed = False 45 | 46 | def handle_tx_logs( 47 | self, 48 | tx_sig: Signature, 49 | slot: int, 50 | logs: list[str], 51 | ): 52 | if self.tx_event_cache.has(str(tx_sig)): 53 | return 54 | 55 | wrapped_events = self.parse_events_from_logs(tx_sig, slot, logs) 56 | for wrapped_event in wrapped_events: 57 | self.event_list_map.get(wrapped_event.event_type).insert(wrapped_event) 58 | 59 | for wrapped_event in wrapped_events: 60 | self.event_emitter.new_event(wrapped_event) 61 | 
62 | self.tx_event_cache.add(str(tx_sig), wrapped_events) 63 | 64 | def parse_events_from_logs(self, tx_sig: Signature, slot: int, logs: list[str]): 65 | wrapped_events = [] 66 | 67 | events = parse_logs(self.program, logs) 68 | 69 | for index, event in enumerate(events): 70 | if event.name in self.event_list_map: 71 | wrapped_event = WrappedEvent( 72 | event_type=event.name, 73 | tx_sig=tx_sig, 74 | slot=slot, 75 | tx_sig_index=index, 76 | data=event.data, 77 | ) 78 | wrapped_events.append(wrapped_event) 79 | 80 | return wrapped_events 81 | 82 | def get_event_list(self, event_type: EventType) -> Optional[EventList]: 83 | return self.event_list_map.get(event_type) 84 | 85 | def get_events_array(self, event_type: EventType) -> Optional[list[WrappedEvent]]: 86 | event_list = self.event_list_map.get(event_type) 87 | return None if event_list is None else event_list.to_array() 88 | 89 | def get_events_by_tx(self, tx_sig: str) -> Optional[list[WrappedEvent]]: 90 | return self.tx_event_cache.get(tx_sig) 91 | -------------------------------------------------------------------------------- /src/driftpy/events/fetch_logs.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import jsonrpcclient 4 | from solana.rpc.async_api import AsyncClient 5 | from solana.rpc.commitment import Commitment 6 | from solders.pubkey import Pubkey 7 | from solders.rpc.responses import ( 8 | RpcConfirmedTransactionStatusWithSignature, 9 | ) 10 | from solders.signature import Signature 11 | 12 | 13 | async def fetch_logs( 14 | connection: AsyncClient, 15 | address: Pubkey, 16 | commitment: Commitment, 17 | before_tx: Signature = None, 18 | until_tx: Signature = None, 19 | limit: int = None, 20 | batch_size: int = None, 21 | ) -> list[(Signature, int, list[str])]: 22 | response = await connection.get_signatures_for_address( 23 | address, before_tx, until_tx, limit, commitment 24 | ) 25 | 26 | if response.value is None: 27 | raise 
Exception("Error with get_signature_for_address") 28 | 29 | signatures = response.value 30 | 31 | sorted_signatures = sorted(signatures, key=lambda x: x.slot) 32 | 33 | filtered_signatures = list( 34 | filter(lambda signature: not signature.err, sorted_signatures) 35 | ) 36 | 37 | if len(filtered_signatures) == 0: 38 | return [] 39 | 40 | batch_size = batch_size if batch_size is not None else 25 41 | chunked_signatures = chunk(filtered_signatures, batch_size) 42 | 43 | chunked_transactions_logs = await asyncio.gather( 44 | *[ 45 | fetch_transactions(connection, signatures, commitment) 46 | for signatures in chunked_signatures 47 | ] 48 | ) 49 | 50 | return [ 51 | transaction_logs 52 | for transactions_logs in chunked_transactions_logs 53 | for transaction_logs in transactions_logs 54 | ] 55 | 56 | 57 | async def fetch_transactions( 58 | connection: AsyncClient, 59 | signatures: list[RpcConfirmedTransactionStatusWithSignature], 60 | commitment: Commitment, 61 | ) -> list[(Signature, int, list[str])]: 62 | rpc_requests = [] 63 | for signature in signatures: 64 | rpc_request = jsonrpcclient.request( 65 | "getTransaction", 66 | ( 67 | str(signature.signature), 68 | {"commitment": commitment, "maxSupportedTransactionVersion": 0}, 69 | ), 70 | ) 71 | rpc_requests.append(rpc_request) 72 | 73 | try: 74 | post = connection._provider.session.post( 75 | connection._provider.endpoint_uri, 76 | json=rpc_requests, 77 | headers={"content-encoding": "gzip"}, 78 | ) 79 | resp = await asyncio.wait_for(post, timeout=10) 80 | except asyncio.TimeoutError: 81 | print("request to rpc timed out") 82 | return [] 83 | 84 | parsed_resp = jsonrpcclient.parse(resp.json()) 85 | 86 | if isinstance(parsed_resp, jsonrpcclient.Error): 87 | raise ValueError(f"Error fetching transactions: {parsed_resp.message}") 88 | 89 | response = [] 90 | for rpc_result in parsed_resp: 91 | if not isinstance(rpc_result, jsonrpcclient.Ok): 92 | raise ValueError(f"Error fetching transactions - not ok: {rpc_result}") 
93 | 94 | if rpc_result.result: 95 | tx_sig = rpc_result.result["transaction"]["signatures"][0] 96 | slot = rpc_result.result["slot"] 97 | logs = rpc_result.result["meta"]["logMessages"] 98 | response.append((tx_sig, slot, logs)) 99 | 100 | return response 101 | 102 | 103 | def chunk(array, size): 104 | return [array[i : i + size] for i in range(0, len(array), size)] 105 | -------------------------------------------------------------------------------- /src/driftpy/events/parse.py: -------------------------------------------------------------------------------- 1 | import binascii 2 | import re 3 | import base64 4 | 5 | from typing import Tuple, Optional 6 | from anchorpy import Program, Event 7 | 8 | DRIFT_PROGRAM_ID: str = "dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH" 9 | DRIFT_PROGRAM_START: str = f"Program {DRIFT_PROGRAM_ID} invoke" 10 | PROGRAM_LOG: str = "Program log: " 11 | PROGRAM_DATA: str = "Program data: " 12 | PROGRAM_LOG_START_INDEX: int = len(PROGRAM_LOG) 13 | PROGRAM_DATA_START_INDEX: int = len(PROGRAM_DATA) 14 | 15 | 16 | class ExecutionContext: 17 | def __init__(self): 18 | self.stack: list[str] = [] 19 | 20 | def program(self): 21 | if len(self.stack) == 0: 22 | raise ValueError("Expected the stack to have elements") 23 | return self.stack[-1] 24 | 25 | def push(self, program: str): 26 | self.stack.append(program) 27 | 28 | def pop(self): 29 | if len(self.stack) == 0: 30 | raise ValueError("Expected the stack to have elements") 31 | return self.stack.pop() 32 | 33 | 34 | def parse_logs(program: Program, logs: list[str]) -> list[Event]: 35 | events = [] 36 | execution = ExecutionContext() 37 | for log in logs: 38 | if log.startswith("Log truncated"): 39 | break 40 | 41 | event, new_program, did_pop = handle_log(execution, log, program) 42 | if event: 43 | events.append(event) 44 | if new_program: 45 | execution.push(new_program) 46 | if did_pop: 47 | execution.pop() 48 | 49 | return events 50 | 51 | 52 | def handle_log( 53 | execution: 
ExecutionContext, log: str, program: Program 54 | ) -> Tuple[Optional[Event], Optional[str], bool]: 55 | if len(execution.stack) > 0 and execution.program() == DRIFT_PROGRAM_ID: 56 | return handle_program_log(log, program) 57 | else: 58 | return (None, *handle_system_log(log)) 59 | 60 | 61 | def handle_program_log( 62 | log: str, program: Program 63 | ) -> Tuple[Optional[Event], Optional[str], bool]: 64 | if log.startswith(PROGRAM_LOG) or log.startswith(PROGRAM_DATA): 65 | log_str = ( 66 | log[PROGRAM_LOG_START_INDEX:] 67 | if log.startswith(PROGRAM_LOG) 68 | else log[PROGRAM_DATA_START_INDEX:] 69 | ) 70 | try: 71 | decoded = base64.b64decode(log_str) 72 | except binascii.Error: 73 | return (None, None, False) 74 | if len(decoded) < 8: 75 | return (None, None, False) 76 | event = program.coder.events.parse(decoded) 77 | return (event, None, False) 78 | else: 79 | return (None, *handle_system_log(log)) 80 | 81 | 82 | def handle_system_log(log: str) -> Tuple[Optional[str], bool]: 83 | log_start = log.split(":")[0] 84 | 85 | if re.findall(r"Program (.*) success", log_start) != []: 86 | return (None, True) 87 | elif log_start.startswith(DRIFT_PROGRAM_START): 88 | return (DRIFT_PROGRAM_ID, False) 89 | elif "invoke" in log_start: 90 | return ("cpi", False) 91 | else: 92 | return (None, False) 93 | -------------------------------------------------------------------------------- /src/driftpy/events/polling_log_provider.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from solana.rpc.async_api import AsyncClient 3 | from solana.rpc.commitment import Commitment 4 | 5 | from solders.pubkey import Pubkey 6 | 7 | from driftpy.events.fetch_logs import fetch_logs 8 | from driftpy.events.types import LogProvider, LogProviderCallback 9 | 10 | 11 | class PollingLogProvider(LogProvider): 12 | def __init__( 13 | self, 14 | connection: AsyncClient, 15 | address: Pubkey, 16 | commitment: Commitment, 17 | frequency: float, 18 | 
batch_size: int = 25, 19 | ): 20 | self.connection = connection 21 | self.address = address 22 | self.commitment = commitment 23 | self.frequency = frequency 24 | self.batch_size = batch_size 25 | self.task = None 26 | self.most_recent_tx = None 27 | 28 | def subscribe(self, callback: LogProviderCallback): 29 | if not self.is_subscribed(): 30 | 31 | async def fetch(): 32 | first_fetch = True 33 | while True: 34 | try: 35 | txs_logs = await fetch_logs( 36 | self.connection, 37 | self.address, 38 | self.commitment, 39 | None, 40 | self.most_recent_tx, 41 | 1 if first_fetch else None, 42 | self.batch_size, 43 | ) 44 | 45 | for signature, slot, logs in txs_logs: 46 | callback(signature, slot, logs) 47 | 48 | first_fetch = False 49 | except Exception as e: 50 | print("Error fetching logs", e) 51 | 52 | await asyncio.sleep(self.frequency) 53 | 54 | self.task = asyncio.create_task(fetch()) 55 | 56 | def is_subscribed(self) -> bool: 57 | return self.task is not None 58 | 59 | def unsubscribe(self): 60 | if self.is_subscribed(): 61 | self.task.cancel() 62 | self.task = None 63 | -------------------------------------------------------------------------------- /src/driftpy/events/sort.py: -------------------------------------------------------------------------------- 1 | from driftpy.events.types import ( 2 | WrappedEvent, 3 | EventSubscriptionOrderBy, 4 | EventSubscriptionOrderDirection, 5 | SortFn, 6 | ) 7 | 8 | 9 | def client_sort_asc_fn() -> int: 10 | return -1 11 | 12 | 13 | def client_sort_desc_fn() -> int: 14 | return 1 15 | 16 | 17 | def blockchain_sort_fn(current_event: WrappedEvent, new_event: WrappedEvent) -> int: 18 | if current_event.slot == new_event.slot: 19 | return -1 if current_event.tx_sig_index < new_event.tx_sig_index else 1 20 | 21 | return -1 if current_event.slot <= new_event.slot else 1 22 | 23 | 24 | def get_sort_fn( 25 | order_by: EventSubscriptionOrderBy, order_dir: EventSubscriptionOrderDirection 26 | ) -> SortFn: 27 | if order_by == "client": 28 
| return client_sort_asc_fn if order_dir == "asc" else client_sort_desc_fn 29 | 30 | return blockchain_sort_fn 31 | -------------------------------------------------------------------------------- /src/driftpy/events/tx_event_cache.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict, List 2 | from dataclasses import dataclass 3 | 4 | from driftpy.events.types import WrappedEvent 5 | 6 | 7 | @dataclass 8 | class Node: 9 | key: str 10 | value: List[WrappedEvent] 11 | next: Optional[any] = None 12 | prev: Optional[any] = None 13 | 14 | 15 | class TxEventCache: 16 | def __init__(self, max_tx: int = 1024): 17 | self.size = 0 18 | self.max_tx = max_tx 19 | self.head = None 20 | self.tail = None 21 | self.cache_map: Dict[str, Node] = {} 22 | 23 | def add(self, key: str, events: List[WrappedEvent]) -> None: 24 | existing_node = self.cache_map.get(key) 25 | if existing_node: 26 | self.detach(existing_node) 27 | self.size -= 1 28 | elif self.size == self.max_tx: 29 | del self.cache_map[self.tail.key] 30 | self.detach(self.tail) 31 | self.size -= 1 32 | 33 | if not self.head: 34 | self.head = self.tail = Node(key, events) 35 | else: 36 | node = Node(key, events, next=self.head) 37 | self.head.prev = node 38 | self.head = node 39 | 40 | self.cache_map[key] = self.head 41 | self.size += 1 42 | 43 | def has(self, key: str) -> bool: 44 | return key in self.cache_map 45 | 46 | def get(self, key: str) -> Optional[List[WrappedEvent]]: 47 | return self.cache_map.get(key).value if key in self.cache_map else None 48 | 49 | def detach(self, node: Node) -> None: 50 | if node.prev is not None: 51 | node.prev.next = node.next 52 | else: 53 | self.head = node.next 54 | 55 | if node.next is not None: 56 | node.next.prev = node.prev 57 | else: 58 | self.tail = node.prev 59 | 60 | def clear(self) -> None: 61 | self.head = None 62 | self.tail = None 63 | self.size = 0 64 | self.cache_map = {} 65 | 
-------------------------------------------------------------------------------- /src/driftpy/events/websocket_log_provider.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import cast 3 | 4 | import websockets.exceptions 5 | from solana.rpc.async_api import AsyncClient 6 | from solana.rpc.commitment import Commitment 7 | from solana.rpc.websocket_api import SolanaWsClientProtocol, connect 8 | from solders.pubkey import Pubkey 9 | from solders.rpc.config import RpcTransactionLogsFilterMentions 10 | 11 | from driftpy.events.types import LogProvider, LogProviderCallback 12 | from driftpy.types import get_ws_url 13 | 14 | 15 | class WebsocketLogProvider(LogProvider): 16 | def __init__( 17 | self, connection: AsyncClient, address: Pubkey, commitment: Commitment 18 | ): 19 | self.connection = connection 20 | self.address = address 21 | self.commitment = commitment 22 | self.task = None 23 | 24 | def subscribe(self, callback: LogProviderCallback): 25 | if not self.is_subscribed(): 26 | self.task = asyncio.create_task(self.subscribe_ws(callback)) 27 | 28 | async def subscribe_ws(self, callback: LogProviderCallback): 29 | endpoint = self.connection._provider.endpoint_uri 30 | if endpoint.startswith("http"): 31 | ws_endpoint = get_ws_url(endpoint) 32 | else: 33 | ws_endpoint = endpoint 34 | 35 | async for ws in connect(ws_endpoint): 36 | ws: SolanaWsClientProtocol 37 | try: 38 | await ws.logs_subscribe( 39 | RpcTransactionLogsFilterMentions(self.address), 40 | self.commitment, 41 | ) 42 | 43 | first_resp = await ws.recv() 44 | subscription_id = cast(int, first_resp[0].result) 45 | 46 | async for msg in ws: 47 | try: 48 | slot = msg[0].result.context.slot 49 | signature = msg[0].result.value.signature 50 | logs = msg[0].result.value.logs 51 | 52 | if msg[0].result.value.err: 53 | continue 54 | 55 | callback(signature, slot, logs) 56 | except Exception as e: 57 | print("Error processing event data", e) 58 
| break 59 | await ws.account_unsubscribe(subscription_id) 60 | except websockets.exceptions.ConnectionClosed: 61 | print("Websocket closed, reconnecting...") 62 | continue 63 | 64 | def is_subscribed(self) -> bool: 65 | return self.task is not None 66 | 67 | def unsubscribe(self): 68 | if self.is_subscribed(): 69 | self.task.cancel() 70 | self.task = None 71 | -------------------------------------------------------------------------------- /src/driftpy/idl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/src/driftpy/idl/__init__.py -------------------------------------------------------------------------------- /src/driftpy/idl/pyth.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.1.0", 3 | "name": "pyth", 4 | "instructions": [ 5 | { 6 | "name": "initialize", 7 | "accounts": [ 8 | { 9 | "name": "price", 10 | "isMut": true, 11 | "isSigner": false 12 | } 13 | ], 14 | "args": [ 15 | { 16 | "name": "price", 17 | "type": "i64" 18 | }, 19 | { 20 | "name": "expo", 21 | "type": "i32" 22 | }, 23 | { 24 | "name": "conf", 25 | "type": "u64" 26 | } 27 | ] 28 | }, 29 | { 30 | "name": "setPrice", 31 | "accounts": [ 32 | { 33 | "name": "price", 34 | "isMut": true, 35 | "isSigner": false 36 | } 37 | ], 38 | "args": [ 39 | { 40 | "name": "price", 41 | "type": "i64" 42 | } 43 | ] 44 | }, 45 | { 46 | "name": "setPriceInfo", 47 | "accounts": [ 48 | { 49 | "name": "price", 50 | "isMut": true, 51 | "isSigner": false 52 | } 53 | ], 54 | "args": [ 55 | { 56 | "name": "price", 57 | "type": "i64" 58 | }, 59 | { 60 | "name": "conf", 61 | "type": "u64" 62 | }, 63 | { 64 | "name": "slot", 65 | "type": "u64" 66 | } 67 | ] 68 | }, 69 | { 70 | "name": "setTwap", 71 | "accounts": [ 72 | { 73 | "name": "price", 74 | "isMut": true, 75 | "isSigner": false 76 | } 77 | ], 78 | "args": [ 79 | { 80 | 
"name": "twap", 81 | "type": "i64" 82 | } 83 | ] 84 | } 85 | ], 86 | "types": [ 87 | { 88 | "name": "PriceStatus", 89 | "type": { 90 | "kind": "enum", 91 | "variants": [ 92 | { 93 | "name": "Unknown" 94 | }, 95 | { 96 | "name": "Trading" 97 | }, 98 | { 99 | "name": "Halted" 100 | }, 101 | { 102 | "name": "Auction" 103 | } 104 | ] 105 | } 106 | }, 107 | { 108 | "name": "CorpAction", 109 | "type": { 110 | "kind": "enum", 111 | "variants": [ 112 | { 113 | "name": "NoCorpAct" 114 | } 115 | ] 116 | } 117 | }, 118 | { 119 | "name": "PriceType", 120 | "type": { 121 | "kind": "enum", 122 | "variants": [ 123 | { 124 | "name": "Unknown" 125 | }, 126 | { 127 | "name": "Price" 128 | }, 129 | { 130 | "name": "TWAP" 131 | }, 132 | { 133 | "name": "Volatility" 134 | } 135 | ] 136 | } 137 | } 138 | ] 139 | } -------------------------------------------------------------------------------- /src/driftpy/idl/sequence_enforcer.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.1.0", 3 | "name": "sequence_enforcer", 4 | "instructions": [ 5 | { 6 | "name": "initialize", 7 | "accounts": [ 8 | { 9 | "name": "sequenceAccount", 10 | "isMut": true, 11 | "isSigner": false 12 | }, 13 | { 14 | "name": "authority", 15 | "isMut": false, 16 | "isSigner": true 17 | }, 18 | { 19 | "name": "systemProgram", 20 | "isMut": false, 21 | "isSigner": false 22 | } 23 | ], 24 | "args": [ 25 | { 26 | "name": "bump", 27 | "type": "u8" 28 | }, 29 | { 30 | "name": "sym", 31 | "type": "string" 32 | } 33 | ] 34 | }, 35 | { 36 | "name": "resetSequenceNumber", 37 | "accounts": [ 38 | { 39 | "name": "sequenceAccount", 40 | "isMut": true, 41 | "isSigner": false 42 | }, 43 | { 44 | "name": "authority", 45 | "isMut": false, 46 | "isSigner": true 47 | } 48 | ], 49 | "args": [ 50 | { 51 | "name": "sequenceNum", 52 | "type": "u64" 53 | } 54 | ] 55 | }, 56 | { 57 | "name": "checkAndSetSequenceNumber", 58 | "accounts": [ 59 | { 60 | "name": "sequenceAccount", 61 | 
"isMut": true, 62 | "isSigner": false 63 | }, 64 | { 65 | "name": "authority", 66 | "isMut": false, 67 | "isSigner": true 68 | } 69 | ], 70 | "args": [ 71 | { 72 | "name": "sequenceNum", 73 | "type": "u64" 74 | } 75 | ] 76 | } 77 | ], 78 | "accounts": [ 79 | { 80 | "name": "SequenceAccount", 81 | "type": { 82 | "kind": "struct", 83 | "fields": [ 84 | { 85 | "name": "sequenceNum", 86 | "type": "u64" 87 | }, 88 | { 89 | "name": "authority", 90 | "type": "publicKey" 91 | } 92 | ] 93 | } 94 | } 95 | ], 96 | "errors": [ 97 | { 98 | "code": 6000, 99 | "name": "SequenceOutOfOrder", 100 | "msg": "Sequence out of order" 101 | } 102 | ] 103 | } -------------------------------------------------------------------------------- /src/driftpy/idl/token_faucet.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.1.0", 3 | "name": "token_faucet", 4 | "instructions": [ 5 | { 6 | "name": "initialize", 7 | "accounts": [ 8 | { 9 | "name": "faucetConfig", 10 | "isMut": true, 11 | "isSigner": false 12 | }, 13 | { 14 | "name": "admin", 15 | "isMut": true, 16 | "isSigner": true 17 | }, 18 | { 19 | "name": "mintAccount", 20 | "isMut": true, 21 | "isSigner": false 22 | }, 23 | { 24 | "name": "rent", 25 | "isMut": false, 26 | "isSigner": false 27 | }, 28 | { 29 | "name": "systemProgram", 30 | "isMut": false, 31 | "isSigner": false 32 | }, 33 | { 34 | "name": "tokenProgram", 35 | "isMut": false, 36 | "isSigner": false 37 | } 38 | ], 39 | "args": [] 40 | }, 41 | { 42 | "name": "mintToUser", 43 | "accounts": [ 44 | { 45 | "name": "faucetConfig", 46 | "isMut": false, 47 | "isSigner": false 48 | }, 49 | { 50 | "name": "mintAccount", 51 | "isMut": true, 52 | "isSigner": false 53 | }, 54 | { 55 | "name": "userTokenAccount", 56 | "isMut": true, 57 | "isSigner": false 58 | }, 59 | { 60 | "name": "mintAuthority", 61 | "isMut": false, 62 | "isSigner": false 63 | }, 64 | { 65 | "name": "tokenProgram", 66 | "isMut": false, 67 | "isSigner": false 68 
| } 69 | ], 70 | "args": [ 71 | { 72 | "name": "amount", 73 | "type": "u64" 74 | } 75 | ] 76 | }, 77 | { 78 | "name": "transferMintAuthority", 79 | "accounts": [ 80 | { 81 | "name": "faucetConfig", 82 | "isMut": false, 83 | "isSigner": false 84 | }, 85 | { 86 | "name": "admin", 87 | "isMut": true, 88 | "isSigner": true 89 | }, 90 | { 91 | "name": "mintAccount", 92 | "isMut": true, 93 | "isSigner": false 94 | }, 95 | { 96 | "name": "mintAuthority", 97 | "isMut": false, 98 | "isSigner": false 99 | }, 100 | { 101 | "name": "tokenProgram", 102 | "isMut": false, 103 | "isSigner": false 104 | } 105 | ], 106 | "args": [] 107 | } 108 | ], 109 | "accounts": [ 110 | { 111 | "name": "FaucetConfig", 112 | "type": { 113 | "kind": "struct", 114 | "fields": [ 115 | { 116 | "name": "admin", 117 | "type": "publicKey" 118 | }, 119 | { 120 | "name": "mint", 121 | "type": "publicKey" 122 | }, 123 | { 124 | "name": "mintAuthority", 125 | "type": "publicKey" 126 | }, 127 | { 128 | "name": "mintAuthorityNonce", 129 | "type": "u8" 130 | } 131 | ] 132 | } 133 | } 134 | ], 135 | "errors": [ 136 | { 137 | "code": 6000, 138 | "name": "InvalidMintAccountAuthority", 139 | "msg": "Program not mint authority" 140 | } 141 | ] 142 | } -------------------------------------------------------------------------------- /src/driftpy/indicative_quotes/__init__.py: -------------------------------------------------------------------------------- 1 | from .indicative_quotes_sender import IndicativeQuotesSender, Quote 2 | 3 | __all__ = ["IndicativeQuotesSender", "Quote"] 4 | -------------------------------------------------------------------------------- /src/driftpy/keypair.py: -------------------------------------------------------------------------------- 1 | from solders.keypair import Keypair 2 | import os 3 | import json 4 | import base58 5 | 6 | 7 | def load_keypair(private_key): 8 | # try to load privateKey as a filepath 9 | if os.path.exists(private_key): 10 | with open(private_key, "r") as file: 11 
| private_key = file.read().strip() 12 | 13 | key_bytes = None 14 | if "[" in private_key and "]" in private_key: 15 | key_bytes = bytes(json.loads(private_key)) 16 | elif "," in private_key: 17 | key_bytes = bytes(map(int, private_key.split(","))) 18 | else: 19 | private_key = private_key.replace(" ", "") 20 | key_bytes = base58.b58decode(private_key) 21 | 22 | return Keypair.from_bytes(key_bytes) 23 | -------------------------------------------------------------------------------- /src/driftpy/market_map/grpc_sub.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, TypeVar 2 | 3 | from solana.rpc.commitment import Commitment 4 | 5 | from driftpy.accounts.grpc.program_account_subscriber import ( 6 | GrpcProgramAccountSubscriber, 7 | ) 8 | from driftpy.accounts.types import GrpcProgramAccountOptions, MarketUpdateCallback 9 | from driftpy.market_map.market_map import MarketMap 10 | from driftpy.memcmp import get_market_type_filter 11 | from driftpy.types import GrpcConfig, market_type_to_string 12 | 13 | T = TypeVar("T") 14 | 15 | 16 | class GrpcSubscription: 17 | def __init__( 18 | self, 19 | grpc_config: GrpcConfig, 20 | market_map: MarketMap, 21 | commitment: Commitment, 22 | on_update: MarketUpdateCallback, 23 | decode: Optional[Callable[[bytes], T]] = None, 24 | ): 25 | self.grpc_config = grpc_config 26 | self.market_map = market_map 27 | self.commitment = commitment 28 | self.on_update = on_update 29 | self.subscriber = None 30 | self.decode = decode 31 | 32 | async def subscribe(self): 33 | if not self.subscriber: 34 | filters = (get_market_type_filter(self.market_map.market_type),) 35 | options = GrpcProgramAccountOptions(filters, self.commitment) 36 | self.subscriber = GrpcProgramAccountSubscriber( 37 | subscription_name=f"{market_type_to_string(self.market_map.market_type)}MarketMap", 38 | program=self.market_map.program, 39 | grpc_config=self.grpc_config, 40 | 
on_update=self.on_update, 41 | options=options, 42 | decode=self.decode, 43 | ) 44 | 45 | await self.subscriber.subscribe() 46 | 47 | async def unsubscribe(self): 48 | if not self.subscriber: 49 | return 50 | await self.subscriber.unsubscribe() 51 | self.subscriber = None 52 | -------------------------------------------------------------------------------- /src/driftpy/market_map/market_map_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | from anchorpy import Program 5 | from solana.rpc.async_api import AsyncClient 6 | from solana.rpc.commitment import Commitment 7 | 8 | from driftpy.types import GrpcConfig, MarketType 9 | 10 | 11 | @dataclass 12 | class WebsocketConfig: 13 | resub_timeout_ms: Optional[int] = None 14 | commitment: Optional[Commitment] = None 15 | 16 | 17 | @dataclass 18 | class MarketMapConfig: 19 | program: Program 20 | market_type: MarketType # perp market map or spot market map 21 | subscription_config: WebsocketConfig 22 | connection: AsyncClient 23 | 24 | 25 | @dataclass 26 | class GrpcMarketMapConfig: 27 | program: Program 28 | market_type: MarketType # perp market map or spot market map 29 | grpc_config: GrpcConfig 30 | connection: AsyncClient 31 | -------------------------------------------------------------------------------- /src/driftpy/market_map/websocket_sub.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, TypeVar 2 | 3 | from solana.rpc.commitment import Commitment 4 | 5 | from driftpy.accounts.types import MarketUpdateCallback, WebsocketProgramAccountOptions 6 | from driftpy.accounts.ws.program_account_subscriber import ( 7 | WebSocketProgramAccountSubscriber, 8 | ) 9 | from driftpy.memcmp import get_market_type_filter 10 | from driftpy.types import market_type_to_string 11 | 12 | T = TypeVar("T") 13 | 14 | 15 | class 
WebsocketSubscription: 16 | def __init__( 17 | self, 18 | market_map, 19 | commitment: Commitment, 20 | on_update: MarketUpdateCallback, 21 | resub_timeout_ms: Optional[int] = None, 22 | decode: Optional[Callable[[bytes], T]] = None, 23 | ): 24 | self.market_map = market_map 25 | self.commitment = commitment 26 | self.on_update = on_update 27 | self.resub_timeout_ms = resub_timeout_ms 28 | self.subscriber = None 29 | self.decode = decode 30 | 31 | async def subscribe(self): 32 | if not self.subscriber: 33 | filters = (get_market_type_filter(self.market_map.market_type),) 34 | options = WebsocketProgramAccountOptions(filters, self.commitment, "base64") 35 | self.subscriber = WebSocketProgramAccountSubscriber( 36 | f"{market_type_to_string(self.market_map.market_type)}MarketMap", 37 | self.market_map.program, 38 | options, 39 | self.on_update, 40 | self.decode, 41 | ) 42 | 43 | await self.subscriber.subscribe() 44 | 45 | async def unsubscribe(self): 46 | if not self.subscriber: 47 | return 48 | await self.subscriber.unsubscribe() 49 | self.subscriber = None 50 | -------------------------------------------------------------------------------- /src/driftpy/math/conversion.py: -------------------------------------------------------------------------------- 1 | from driftpy.constants.numeric_constants import PRICE_PRECISION 2 | 3 | 4 | def convert_to_number(big_number: int, precision: int = PRICE_PRECISION) -> float: 5 | if not big_number: 6 | return 0 7 | return big_number // precision + (big_number % precision) / precision 8 | -------------------------------------------------------------------------------- /src/driftpy/math/exchange_status.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from driftpy.types import ( 3 | PerpMarketAccount, 4 | SpotMarketAccount, 5 | StateAccount, 6 | is_one_of_variant, 7 | is_variant, 8 | ) 9 | 10 | 11 | class ExchangeStatusValues: 12 | Active = 1 13 | DepositPaused = 2 
14 | WithdrawPaused = 4 15 | AmmPaused = 8 16 | FillPaused = 16 17 | LiqPaused = 32 18 | FundingPaused = 64 19 | SettlePnlPaused = 128 20 | 21 | 22 | def exchange_paused(state: StateAccount) -> bool: 23 | return not is_variant(state.exchange_status, "Active") 24 | 25 | 26 | def fill_paused( 27 | state: StateAccount, market: Union[PerpMarketAccount, SpotMarketAccount] 28 | ) -> bool: 29 | return ( 30 | state.exchange_status & ExchangeStatusValues.FillPaused 31 | ) == ExchangeStatusValues.FillPaused or is_one_of_variant( 32 | market.status, ["Paused", "FillPaused"] 33 | ) 34 | 35 | 36 | def amm_paused( 37 | state: StateAccount, market: Union[PerpMarketAccount, SpotMarketAccount] 38 | ) -> bool: 39 | return ( 40 | state.exchange_status & ExchangeStatusValues.AmmPaused 41 | ) == ExchangeStatusValues.AmmPaused or is_one_of_variant( 42 | market.status, ["Paused", "AmmPaused"] 43 | ) 44 | -------------------------------------------------------------------------------- /src/driftpy/math/fuel.py: -------------------------------------------------------------------------------- 1 | from driftpy.types import SpotMarketAccount, PerpMarketAccount 2 | from driftpy.constants.numeric_constants import QUOTE_PRECISION, FUEL_WINDOW 3 | 4 | 5 | def calculate_insurance_fuel_bonus( 6 | spot_market: SpotMarketAccount, token_stake_amount: int, fuel_bonus_numerator: int 7 | ) -> int: 8 | insurance_fund_fuel = ( 9 | abs(token_stake_amount) * fuel_bonus_numerator 10 | ) * spot_market.fuel_boost_insurance 11 | insurace_fund_fuel_per_day = insurance_fund_fuel // FUEL_WINDOW 12 | insurance_fund_fuel_scaled = insurace_fund_fuel_per_day // (QUOTE_PRECISION // 10) 13 | 14 | return insurance_fund_fuel_scaled 15 | 16 | 17 | def calculate_spot_fuel_bonus( 18 | spot_market: SpotMarketAccount, signed_token_value: int, fuel_bonus_numerator: int 19 | ) -> int: 20 | spot_fuel_scaled: int 21 | 22 | # dust 23 | if abs(signed_token_value) <= QUOTE_PRECISION: 24 | spot_fuel_scaled = 0 25 | elif 
signed_token_value > 0: 26 | deposit_fuel = ( 27 | abs(signed_token_value) * fuel_bonus_numerator 28 | ) * spot_market.fuel_boost_deposits 29 | deposit_fuel_per_day = deposit_fuel // FUEL_WINDOW 30 | spot_fuel_scaled = deposit_fuel_per_day // (QUOTE_PRECISION // 10) 31 | else: 32 | borrow_fuel = ( 33 | abs(signed_token_value) * fuel_bonus_numerator 34 | ) * spot_market.fuel_boost_borrows 35 | borrow_fuel_per_day = borrow_fuel // FUEL_WINDOW 36 | spot_fuel_scaled = borrow_fuel_per_day // (QUOTE_PRECISION // 10) 37 | 38 | return spot_fuel_scaled 39 | 40 | 41 | def calculate_perp_fuel_bonus( 42 | perp_market: PerpMarketAccount, base_asset_value: int, fuel_bonus_numerator: int 43 | ) -> int: 44 | perp_fuel_scaled: int 45 | 46 | # dust 47 | if abs(base_asset_value) <= QUOTE_PRECISION: 48 | perp_fuel_scaled = 0 49 | else: 50 | perp_fuel = ( 51 | abs(base_asset_value) * fuel_bonus_numerator 52 | ) * perp_market.fuel_boost_position 53 | perp_fuel_per_day = perp_fuel // FUEL_WINDOW 54 | perp_fuel_scaled = perp_fuel_per_day // (QUOTE_PRECISION // 10) 55 | 56 | return perp_fuel_scaled 57 | -------------------------------------------------------------------------------- /src/driftpy/math/market.py: -------------------------------------------------------------------------------- 1 | from driftpy.types import OraclePriceData, PerpMarketAccount, PositionDirection 2 | 3 | 4 | def calculate_bid_price( 5 | market: PerpMarketAccount, oracle_price_data: OraclePriceData 6 | ) -> int: 7 | from driftpy.math.amm import calculate_updated_amm_spread_reserves, calculate_price 8 | 9 | ( 10 | base_asset_reserve, 11 | quote_asset_reserve, 12 | new_peg, 13 | _, 14 | ) = calculate_updated_amm_spread_reserves( 15 | market.amm, PositionDirection.Short(), oracle_price_data 16 | ) 17 | 18 | return calculate_price(base_asset_reserve, quote_asset_reserve, new_peg) 19 | 20 | 21 | def calculate_ask_price( 22 | market: PerpMarketAccount, oracle_price_data: OraclePriceData 23 | ) -> int: 24 | from 
driftpy.math.amm import calculate_updated_amm_spread_reserves, calculate_price 25 | 26 | ( 27 | base_asset_reserve, 28 | quote_asset_reserve, 29 | new_peg, 30 | _, 31 | ) = calculate_updated_amm_spread_reserves( 32 | market.amm, PositionDirection.Long(), oracle_price_data 33 | ) 34 | 35 | return calculate_price(base_asset_reserve, quote_asset_reserve, new_peg) 36 | -------------------------------------------------------------------------------- /src/driftpy/math/spot_market.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from driftpy.math.utils import div_ceil 4 | from driftpy.types import ( 5 | OraclePriceData, 6 | SpotBalanceType, 7 | SpotMarketAccount, 8 | is_variant, 9 | ) 10 | 11 | 12 | def get_signed_token_amount(amount, balance_type): 13 | return amount if is_variant(balance_type, "Deposit") else -abs(amount) 14 | 15 | 16 | def get_token_amount( 17 | balance: int, spot_market: SpotMarketAccount, balance_type: SpotBalanceType 18 | ) -> int: 19 | precision_decrease = 10 ** (19 - spot_market.decimals) 20 | 21 | if is_variant(balance_type, "Deposit"): 22 | return int( 23 | (balance * spot_market.cumulative_deposit_interest) / precision_decrease 24 | ) 25 | else: 26 | return div_ceil( 27 | balance * spot_market.cumulative_borrow_interest, precision_decrease 28 | ) 29 | 30 | 31 | def get_token_value( 32 | amount, spot_decimals, oracle_price_data: Union[OraclePriceData, int] 33 | ): 34 | precision_decrease = 10**spot_decimals 35 | if isinstance(oracle_price_data, OraclePriceData): 36 | return amount * oracle_price_data.price // precision_decrease 37 | else: 38 | return amount * oracle_price_data // precision_decrease 39 | 40 | 41 | def cast_to_spot_precision( 42 | amount: Union[float, int], spot_market: SpotMarketAccount 43 | ) -> int: 44 | precision = 10**spot_market.decimals 45 | return int(amount * precision) 46 | 
-------------------------------------------------------------------------------- /src/driftpy/math/user_status.py: -------------------------------------------------------------------------------- 1 | from driftpy.types import UserAccount, UserStatus 2 | 3 | 4 | def is_user_protected_maker(user_account: UserAccount) -> bool: 5 | return (user_account.status & UserStatus.PROTECTED_MAKER) > 0 6 | -------------------------------------------------------------------------------- /src/driftpy/math/utils.py: -------------------------------------------------------------------------------- 1 | def clamp_num(x: int, min_clamp: int, max_clamp: int) -> int: 2 | return max(min_clamp, min(x, max_clamp)) 3 | 4 | 5 | def div_ceil(a: int, b: int) -> int: 6 | if b == 0: 7 | return a 8 | 9 | quotient = a // b 10 | remainder = a % b 11 | 12 | if remainder > 0: 13 | quotient += 1 14 | 15 | return quotient 16 | 17 | 18 | def sig_num(x: int) -> int: 19 | return -1 if x < 0 else 1 20 | 21 | 22 | def time_remaining_until_update(now: int, last_update_ts: int, update_period: int) -> int: 23 | if update_period <= 0: 24 | raise ValueError("update_period must be positive") 25 | 26 | time_since = now - last_update_ts 27 | 28 | if update_period == 1: 29 | return max(0, 1 - time_since) 30 | 31 | # Calculate delay-based adjustment 32 | last_delay = last_update_ts % update_period 33 | max_delay = update_period // 3 34 | next_wait = update_period - last_delay 35 | 36 | if last_delay > max_delay: 37 | next_wait = 2 * update_period - last_delay 38 | 39 | return max(0, next_wait - time_since) 40 | -------------------------------------------------------------------------------- /src/driftpy/memcmp.py: -------------------------------------------------------------------------------- 1 | import base58 2 | from anchorpy.coder.accounts import _account_discriminator 3 | from solana.rpc.types import MemcmpOpts 4 | 5 | from driftpy.name import encode_name 6 | from driftpy.types import MarketType, is_variant 7 | 8 
| 9 | def get_user_filter() -> MemcmpOpts: 10 | return MemcmpOpts(0, base58.b58encode(_account_discriminator("User")).decode()) 11 | 12 | 13 | def get_non_idle_user_filter() -> MemcmpOpts: 14 | return MemcmpOpts(4350, base58.b58encode(bytes([0])).decode()) 15 | 16 | 17 | def get_user_with_auction_filter() -> MemcmpOpts: 18 | return MemcmpOpts(4354, base58.b58encode(bytes([1])).decode()) 19 | 20 | 21 | def get_user_with_order_filter() -> MemcmpOpts: 22 | return MemcmpOpts(4352, base58.b58encode(bytes([1])).decode()) 23 | 24 | 25 | def get_user_without_order_filter() -> MemcmpOpts: 26 | return MemcmpOpts(4352, base58.b58encode(bytes([0])).decode()) 27 | 28 | 29 | def get_user_that_has_been_lp_filter() -> MemcmpOpts: 30 | return MemcmpOpts(4267, base58.b58encode(bytes([99])).decode()) 31 | 32 | 33 | def get_user_with_name_filter(name: str) -> MemcmpOpts: 34 | encoded_name_bytes = encode_name(name) 35 | return MemcmpOpts(72, base58.b58encode(bytes(encoded_name_bytes)).decode()) 36 | 37 | 38 | def get_users_with_pool_id_filter(pool_id: int) -> MemcmpOpts: 39 | return MemcmpOpts(4356, base58.b58encode(bytes([pool_id])).decode()) 40 | 41 | 42 | def get_market_type_filter(market_type: MarketType) -> MemcmpOpts: 43 | if is_variant(market_type, "Perp"): 44 | return MemcmpOpts( 45 | 0, base58.b58encode(_account_discriminator("PerpMarket")).decode() 46 | ) 47 | else: 48 | return MemcmpOpts( 49 | 0, base58.b58encode(_account_discriminator("SpotMarket")).decode() 50 | ) 51 | 52 | 53 | def get_user_stats_filter() -> MemcmpOpts: 54 | return MemcmpOpts(0, base58.b58encode(_account_discriminator("UserStats")).decode()) 55 | 56 | 57 | def get_user_stats_is_referred_filter() -> MemcmpOpts: 58 | # offset 188, bytes for 2 59 | return MemcmpOpts(188, base58.b58encode(bytes([2])).decode()) 60 | 61 | 62 | def get_user_stats_is_referred_or_referrer_filter() -> MemcmpOpts: 63 | # offset 188, bytes for 3 64 | return MemcmpOpts(188, base58.b58encode(bytes([3])).decode()) 65 | 66 | 67 | def 
get_signed_msg_user_orders_filter() -> MemcmpOpts: 68 | return MemcmpOpts( 69 | 0, base58.b58encode(_account_discriminator("SignedMsgUserOrders")).decode() 70 | ) 71 | -------------------------------------------------------------------------------- /src/driftpy/name.py: -------------------------------------------------------------------------------- 1 | from struct import pack_into 2 | 3 | MAX_LENGTH = 32 4 | 5 | 6 | def encode_name(name: str) -> list[int]: 7 | if len(name) > 32: 8 | raise Exception("name too long") 9 | 10 | name_bytes = bytearray(32) 11 | pack_into(f"{len(name)}s", name_bytes, 0, name.encode("utf-8")) 12 | offset = len(name) 13 | for _ in range(32 - len(name)): 14 | pack_into("1s", name_bytes, offset, " ".encode("utf-8")) 15 | offset += 1 16 | 17 | str_name_bytes = name_bytes.hex() 18 | name_byte_array = [] 19 | for i in range(0, len(str_name_bytes), 2): 20 | name_byte_array.append(int(str_name_bytes[i : i + 2], 16)) 21 | 22 | return name_byte_array 23 | -------------------------------------------------------------------------------- /src/driftpy/oracles/oracle_id.py: -------------------------------------------------------------------------------- 1 | from solders.pubkey import Pubkey 2 | 3 | from driftpy.types import OracleSource, OracleSourceNum 4 | 5 | 6 | def get_oracle_source_num(source: OracleSource) -> int: 7 | source_str = str(source) 8 | 9 | if "Pyth1M" in source_str: 10 | return OracleSourceNum.PYTH_1M 11 | elif "Pyth1K" in source_str: 12 | return OracleSourceNum.PYTH_1K 13 | elif "PythPull" in source_str: 14 | return OracleSourceNum.PYTH_PULL 15 | elif "Pyth1KPull" in source_str: 16 | return OracleSourceNum.PYTH_1K_PULL 17 | elif "Pyth1MPull" in source_str: 18 | return OracleSourceNum.PYTH_1M_PULL 19 | elif "PythStableCoinPull" in source_str: 20 | return OracleSourceNum.PYTH_STABLE_COIN_PULL 21 | elif "PythStableCoin" in source_str: 22 | return OracleSourceNum.PYTH_STABLE_COIN 23 | elif "PythLazer1K" in source_str: 24 | return 
OracleSourceNum.PYTH_LAZER_1K 25 | elif "PythLazer1M" in source_str: 26 | return OracleSourceNum.PYTH_LAZER_1M 27 | elif "PythLazerStableCoin" in source_str: 28 | return OracleSourceNum.PYTH_LAZER_STABLE_COIN 29 | elif "PythLazer" in source_str: 30 | return OracleSourceNum.PYTH_LAZER 31 | elif "Pyth" in source_str: 32 | return OracleSourceNum.PYTH 33 | elif "SwitchboardOnDemand" in source_str: 34 | return OracleSourceNum.SWITCHBOARD_ON_DEMAND 35 | elif "Switchboard" in source_str: 36 | return OracleSourceNum.SWITCHBOARD 37 | elif "QuoteAsset" in source_str: 38 | return OracleSourceNum.QUOTE_ASSET 39 | elif "Prelaunch" in source_str: 40 | return OracleSourceNum.PRELAUNCH 41 | 42 | raise ValueError("Invalid oracle source") 43 | 44 | 45 | def get_oracle_id(public_key: Pubkey, source: OracleSource) -> str: 46 | """ 47 | Returns the oracle id for a given oracle and source 48 | """ 49 | return f"{str(public_key)}-{get_oracle_source_num(source)}" 50 | -------------------------------------------------------------------------------- /src/driftpy/oracles/strict_oracle_price.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | 4 | class StrictOraclePrice: 5 | def __init__(self, current: int, twap: Optional[int] = None): 6 | self.current = current 7 | self.twap = twap 8 | 9 | def max(self) -> int: 10 | if self.twap: 11 | return max(self.twap, self.current) 12 | else: 13 | return self.current 14 | 15 | def min(self): 16 | if self.twap: 17 | return min(self.twap, self.current) 18 | else: 19 | return self.current 20 | -------------------------------------------------------------------------------- /src/driftpy/priority_fees/priority_fee_subscriber.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from dataclasses import dataclass 3 | 4 | import jsonrpcclient 5 | from solana.rpc.async_api import AsyncClient 6 | 7 | 8 | @dataclass 9 | class 
PriorityFeeConfig: 10 | connection: AsyncClient 11 | frequency_secs: int 12 | addresses: list[str] 13 | slots_to_check: int = 10 14 | 15 | 16 | class PriorityFeeSubscriber: 17 | def __init__(self, config: PriorityFeeConfig): 18 | self.connection = config.connection 19 | self.frequency_ms = config.frequency_secs 20 | self.addresses = config.addresses 21 | self.slots_to_check = config.slots_to_check 22 | 23 | self.latest_priority_fee = 0 24 | self.avg_priority_fee = 0 25 | self.max_priority_fee = 0 26 | self.last_slot_seen = 0 27 | self.subscribed = False 28 | 29 | async def subscribe(self): 30 | if self.subscribed: 31 | return 32 | 33 | self.subscribed = True 34 | 35 | asyncio.create_task(self.poll()) 36 | 37 | async def poll(self): 38 | while self.subscribed: 39 | asyncio.create_task(self.load()) 40 | await asyncio.sleep(self.frequency_ms) 41 | 42 | async def load(self): 43 | rpc_request = jsonrpcclient.request( 44 | "getRecentPrioritizationFees", [self.addresses] 45 | ) 46 | 47 | post = self.connection._provider.session.post( 48 | self.connection._provider.endpoint_uri, 49 | json=rpc_request, 50 | headers={"content-encoding": "gzip"}, 51 | ) 52 | 53 | resp = await asyncio.wait_for(post, timeout=20) 54 | 55 | parsed_resp = jsonrpcclient.parse(resp.json()) 56 | 57 | if isinstance(parsed_resp, jsonrpcclient.Error): 58 | raise ValueError(f"Error fetching priority fees: {parsed_resp.message}") 59 | 60 | if not isinstance(parsed_resp, jsonrpcclient.Ok): 61 | raise ValueError(f"Error fetching priority fees - not ok: {parsed_resp}") 62 | 63 | result = parsed_resp.result 64 | 65 | desc_results = sorted(result, key=lambda x: x["slot"], reverse=True)[ 66 | : self.slots_to_check 67 | ] 68 | 69 | if not desc_results: 70 | return 71 | 72 | self.latest_priority_fee = desc_results[0]["prioritizationFee"] 73 | self.last_slot_seen = desc_results[0]["slot"] 74 | self.avg_priority_fee = sum( 75 | item["prioritizationFee"] for item in desc_results 76 | ) / len(desc_results) 77 | 
self.max_priority_fee = max(item["prioritizationFee"] for item in desc_results) 78 | 79 | async def unsubscribe(self): 80 | if self.subscribed: 81 | self.subscribed = False 82 | -------------------------------------------------------------------------------- /src/driftpy/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/src/driftpy/py.typed -------------------------------------------------------------------------------- /src/driftpy/slot/slot_subscriber.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from events import Events as EventEmitter 4 | from solana.rpc.websocket_api import SolanaWsClientProtocol, connect 5 | 6 | from driftpy.dlob.client_types import SlotSource 7 | from driftpy.drift_client import DriftClient 8 | from driftpy.types import get_ws_url 9 | 10 | 11 | class SlotSubscriber(SlotSource): 12 | def __init__(self, drift_client: DriftClient): 13 | self.current_slot = 0 14 | self.subscription_id = None 15 | self.connection = drift_client.connection 16 | self.program = drift_client.program 17 | self.ws = None 18 | self.subscribed = False 19 | self.event_emitter = EventEmitter(("on_slot_change")) 20 | self.event_emitter.on("on_slot_change") 21 | 22 | async def on_slot_change(self, slot_info: int): 23 | self.current_slot = slot_info 24 | self.event_emitter.on_slot_change(slot_info) 25 | 26 | async def subscribe(self): 27 | if self.subscribed: 28 | return 29 | self.task = asyncio.create_task(self.subscribe_ws()) 30 | return self.task 31 | 32 | async def subscribe_ws(self): 33 | if self.subscription_id is not None: 34 | return 35 | 36 | current_slot_response = await self.connection.get_slot() 37 | self.current_slot = int(current_slot_response.value) 38 | 39 | endpoint = self.program.provider.connection._provider.endpoint_uri 40 | ws_endpoint = 
get_ws_url(endpoint) 41 | while True: 42 | try: 43 | async with connect(ws_endpoint) as ws: 44 | self.subscribed = True 45 | self.ws = ws 46 | ws: SolanaWsClientProtocol 47 | self.subscription_id = await ws.slot_subscribe() 48 | 49 | await ws.recv() 50 | 51 | async for msg in ws: 52 | await self.on_slot_change(msg[0].result.slot) 53 | 54 | except Exception as e: 55 | print(f"Error in SlotSubscriber: {e}") 56 | if self.ws: 57 | await self.ws.close() 58 | self.ws = None 59 | await asyncio.sleep(5) # wait a second before we retry 60 | 61 | def get_slot(self) -> int: 62 | return self.current_slot 63 | 64 | async def unsubscribe(self): 65 | if self.ws: 66 | await self.ws.close() 67 | self.ws = None 68 | self.subscribed = False 69 | -------------------------------------------------------------------------------- /src/driftpy/swift/create_verify_ix.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | from construct import Int8ul, Int16ul, Struct 4 | from solana.constants import ED25519_PROGRAM_ID 5 | from solders.instruction import Instruction 6 | 7 | ED25519_INSTRUCTION_LEN = 16 8 | SIGNATURE_LEN = 64 9 | PUBKEY_LEN = 32 10 | MAGIC_LEN = 4 11 | MESSAGE_SIZE_LEN = 2 12 | 13 | 14 | def trim_feed_id(feed_id: str) -> str: 15 | if feed_id.startswith("0x"): 16 | return feed_id[2:] 17 | return feed_id 18 | 19 | 20 | def get_feed_id_uint8_array(feed_id: str) -> bytes: 21 | trimmed_feed_id = trim_feed_id(feed_id) 22 | return bytes.fromhex(trimmed_feed_id) 23 | 24 | 25 | def get_ed25519_args_from_hex( 26 | hex_str: str, custom_instruction_index: Optional[int] = None 27 | ) -> Dict[str, bytes]: 28 | cleaned_hex = hex_str[2:] if hex_str.startswith("0x") else hex_str 29 | buffer = bytes.fromhex(cleaned_hex) 30 | 31 | signature_offset = MAGIC_LEN 32 | public_key_offset = signature_offset + SIGNATURE_LEN 33 | message_data_size_offset = public_key_offset + PUBKEY_LEN 34 | message_data_offset = 
message_data_size_offset + MESSAGE_SIZE_LEN 35 | 36 | signature = buffer[signature_offset : signature_offset + SIGNATURE_LEN] 37 | public_key = buffer[public_key_offset : public_key_offset + PUBKEY_LEN] 38 | message_size = buffer[message_data_size_offset] | ( 39 | buffer[message_data_size_offset + 1] << 8 40 | ) 41 | message = buffer[message_data_offset : message_data_offset + message_size] 42 | 43 | if len(public_key) != PUBKEY_LEN: 44 | raise ValueError("Invalid public key length") 45 | 46 | if len(signature) != SIGNATURE_LEN: 47 | raise ValueError("Invalid signature length") 48 | 49 | return { 50 | "public_key": public_key, 51 | "signature": signature, 52 | "message": message, 53 | "instruction_index": custom_instruction_index, 54 | } 55 | 56 | 57 | def read_uint16_le(data: bytes, offset: int) -> int: 58 | return data[offset] | (data[offset + 1] << 8) 59 | 60 | 61 | ED25519_INSTRUCTION_LAYOUT = Struct( 62 | "num_signatures" / Int8ul, 63 | "padding" / Int8ul, 64 | "signature_offset" / Int16ul, 65 | "signature_instruction_index" / Int16ul, 66 | "public_key_offset" / Int16ul, 67 | "public_key_instruction_index" / Int16ul, 68 | "message_data_offset" / Int16ul, 69 | "message_data_size" / Int16ul, 70 | "message_instruction_index" / Int16ul, 71 | ) 72 | 73 | 74 | def create_minimal_ed25519_verify_ix( 75 | custom_instruction_index: int, 76 | message_offset: int, 77 | custom_instruction_data: bytes, 78 | magic_len: Optional[int] = None, 79 | ) -> Instruction: 80 | signature_offset = message_offset + (MAGIC_LEN if magic_len is None else magic_len) 81 | public_key_offset = signature_offset + SIGNATURE_LEN 82 | message_data_size_offset = public_key_offset + PUBKEY_LEN 83 | message_data_offset = message_data_size_offset + MESSAGE_SIZE_LEN 84 | 85 | message_data_size = read_uint16_le( 86 | custom_instruction_data, message_data_size_offset - message_offset 87 | ) 88 | 89 | instruction_data = ED25519_INSTRUCTION_LAYOUT.build( 90 | dict( 91 | num_signatures=1, 92 | padding=0, 93 
| signature_offset=signature_offset, 94 | signature_instruction_index=custom_instruction_index, 95 | public_key_offset=public_key_offset, 96 | public_key_instruction_index=custom_instruction_index, 97 | message_data_offset=message_data_offset, 98 | message_data_size=message_data_size, 99 | message_instruction_index=custom_instruction_index, 100 | ) 101 | ) 102 | 103 | return Instruction( 104 | accounts=[], 105 | program_id=ED25519_PROGRAM_ID, 106 | data=instruction_data, 107 | ) 108 | -------------------------------------------------------------------------------- /src/driftpy/swift/util.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import hashlib 3 | import random 4 | import string 5 | 6 | 7 | def digest_signature(signature: bytes) -> str: 8 | """ 9 | Create a SHA-256 hash of a signature and return it as a base64 string. 10 | 11 | Args: 12 | signature: The signature bytes to hash 13 | 14 | Returns: 15 | Base64-encoded SHA-256 hash of the signature 16 | """ 17 | hash_object = hashlib.sha256(signature) 18 | return base64.b64encode(hash_object.digest()).decode("utf-8") 19 | 20 | 21 | def generate_signed_msg_uuid(length=8) -> bytes: 22 | """ 23 | Generate a random string similar to nanoid with specified length and convert to bytes. 
24 | 25 | Args: 26 | length: Length of the random string to generate (default: 8) 27 | 28 | Returns: 29 | Bytes representation of the generated random string 30 | """ 31 | chars = string.ascii_letters + string.digits + "-_" # nanoid alphabet 32 | random_id = "".join(random.choice(chars) for _ in range(length)) 33 | return random_id.encode("utf-8") 34 | -------------------------------------------------------------------------------- /src/driftpy/tx/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/src/driftpy/tx/__init__.py -------------------------------------------------------------------------------- /src/driftpy/tx/fast_tx_sender.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from solders.hash import Hash 4 | 5 | from solana.rpc.async_api import AsyncClient 6 | from solana.rpc.types import TxOpts 7 | from solana.rpc.commitment import Commitment, Confirmed 8 | 9 | from driftpy.tx.standard_tx_sender import StandardTxSender 10 | 11 | 12 | class FastTxSender(StandardTxSender): 13 | """ 14 | The FastTxSender will refresh the latest blockhash in the background to save an RPC when building transactions. 
15 | """ 16 | 17 | def __init__( 18 | self, 19 | connection: AsyncClient, 20 | opts: TxOpts, 21 | blockhash_refresh_interval_secs: int, 22 | blockhash_commitment: Commitment = Confirmed, 23 | ): 24 | super().__init__(connection, opts, blockhash_commitment) 25 | self.blockhash_refresh_interval = blockhash_refresh_interval_secs 26 | self.recent_blockhash = None 27 | 28 | async def subscribe_blockhash(self): 29 | """ 30 | Must be called with asyncio.create_task to prevent blocking 31 | """ 32 | while True: 33 | try: 34 | blockhash_info = await self.connection.get_latest_blockhash( 35 | self.blockhash_commitment 36 | ) 37 | self.recent_blockhash = blockhash_info.value.blockhash 38 | except Exception as e: 39 | print(f"Error in subscribe_blockhash: {e}") 40 | await asyncio.sleep(self.blockhash_refresh_interval) 41 | 42 | async def fetch_latest_blockhash(self) -> Hash: 43 | if self.recent_blockhash is None: 44 | asyncio.create_task(self.subscribe_blockhash()) 45 | return await super().get_blockhash() 46 | return self.recent_blockhash 47 | -------------------------------------------------------------------------------- /src/driftpy/tx/jito_subscriber.py: -------------------------------------------------------------------------------- 1 | raise ImportError( 2 | "The jito_subscriber module is deprecated and has been removed in driftpy." 
3 | ) 4 | 5 | import asyncio 6 | import random 7 | from typing import Optional, Tuple, Union 8 | 9 | from jito_searcher_client.async_searcher import ( 10 | get_async_searcher_client, # type: ignore 11 | ) 12 | from jito_searcher_client.generated.searcher_pb2 import ( 13 | ConnectedLeadersRequest, 14 | ConnectedLeadersResponse, 15 | GetTipAccountsRequest, 16 | GetTipAccountsResponse, 17 | SubscribeBundleResultsRequest, 18 | ) # type: ignore 19 | from jito_searcher_client.generated.searcher_pb2_grpc import ( 20 | SearcherServiceStub, # type: ignore 21 | ) 22 | from solana.rpc.async_api import AsyncClient 23 | from solana.rpc.commitment import Confirmed 24 | from solders.keypair import Keypair # type: ignore 25 | from solders.pubkey import Pubkey # type: ignore 26 | from solders.system_program import TransferParams, transfer 27 | 28 | 29 | class JitoSubscriber: 30 | def __init__( 31 | self, 32 | refresh_rate: int, 33 | kp: Keypair, 34 | connection: AsyncClient, 35 | block_engine_url: str, 36 | ): 37 | self.cache = [] # type: ignore 38 | self.refresh_rate = refresh_rate 39 | self.kp = kp 40 | self.searcher_client: Optional[SearcherServiceStub] = None 41 | self.connection = connection 42 | self.tip_accounts: list[Pubkey] = [] 43 | self.block_engine_url = block_engine_url 44 | self.bundle_subscription = None 45 | 46 | async def subscribe(self): 47 | self.searcher_client = await get_async_searcher_client( 48 | self.block_engine_url, self.kp 49 | ) 50 | asyncio.create_task(self._subscribe()) 51 | 52 | async def _subscribe(self): 53 | self.bundle_subscription = self.searcher_client.SubscribeBundleResults( 54 | SubscribeBundleResultsRequest() 55 | ) 56 | tip_accounts: GetTipAccountsResponse = ( 57 | await self.searcher_client.GetTipAccounts(GetTipAccountsRequest()) 58 | ) # type: ignore 59 | for account in tip_accounts.accounts: 60 | self.tip_accounts.append(Pubkey.from_string(account)) 61 | while True: 62 | try: 63 | self.cache.clear() 64 | current_slot = (await 
self.connection.get_slot(Confirmed)).value 65 | leaders: ConnectedLeadersResponse = ( 66 | await self.searcher_client.GetConnectedLeaders( 67 | ConnectedLeadersRequest() 68 | ) 69 | ) # type: ignore 70 | for slot_list in leaders.connected_validators.values(): 71 | slots = slot_list.slots 72 | for slot in slots: 73 | if slot > current_slot: 74 | self.cache.append(slot) 75 | self.cache.sort() 76 | 77 | except Exception as e: 78 | print(e) 79 | await asyncio.sleep(30) 80 | await self._subscribe() 81 | await asyncio.sleep(self.refresh_rate) 82 | 83 | def send_to_jito(self, current_slot: int) -> bool: 84 | for slot in range(current_slot - 5, current_slot + 5): 85 | if slot in self.cache: 86 | return True 87 | return False 88 | 89 | def get_tip_ix(self, signer: Pubkey, tip_amount: int = 1_000_000): 90 | tip_account = random.choice(self.tip_accounts) 91 | transfer_params = TransferParams( 92 | from_pubkey=signer, to_pubkey=tip_account, lamports=tip_amount 93 | ) 94 | return transfer(transfer_params) 95 | 96 | async def process_bundle_result(self, uuid: str) -> Tuple[bool, Union[int, str]]: 97 | while True: 98 | bundle_result = await self.bundle_subscription.read() # type: ignore 99 | if bundle_result.bundle_id == uuid: 100 | if bundle_result.HasField("accepted"): 101 | slot = getattr(getattr(bundle_result, "accepted"), "slot") 102 | return True, slot or 0 103 | elif bundle_result.HasField("rejected"): 104 | msg = getattr( 105 | getattr( 106 | getattr(bundle_result, "rejected"), "simulation_failure" 107 | ), 108 | "msg", 109 | ) 110 | return False, msg or "" 111 | -------------------------------------------------------------------------------- /src/driftpy/tx/jito_tx_sender.py: -------------------------------------------------------------------------------- 1 | raise ImportError( 2 | "The jito_tx_sender module is deprecated and has been removed in driftpy. 
" 3 | ) 4 | 5 | 6 | class JitoTxSender: 7 | def __init__( 8 | self, 9 | drift_client, 10 | opts, 11 | block_engine_url, 12 | jito_keypair, 13 | blockhash_commitment, 14 | blockhash_refresh_interval_secs=None, 15 | tip_amount=None, 16 | ): 17 | raise NotImplementedError( 18 | "JitoTxSender is deprecated and has been removed in driftpy." 19 | ) 20 | -------------------------------------------------------------------------------- /src/driftpy/tx/standard_tx_sender.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence 2 | 3 | from solana.rpc.async_api import AsyncClient 4 | from solana.rpc.commitment import Commitment, Confirmed 5 | from solana.rpc.types import TxOpts 6 | from solders.address_lookup_table_account import AddressLookupTableAccount 7 | from solders.hash import Hash 8 | from solders.instruction import Instruction 9 | from solders.keypair import Keypair 10 | from solders.message import MessageV0 11 | from solders.rpc.responses import SendTransactionResp 12 | from solders.transaction import VersionedTransaction 13 | 14 | from driftpy.tx.types import TxSender, TxSigAndSlot 15 | 16 | 17 | class StandardTxSender(TxSender): 18 | def __init__( 19 | self, 20 | connection: AsyncClient, 21 | opts: TxOpts, 22 | blockhash_commitment: Commitment = Confirmed, 23 | ): 24 | self.connection = connection 25 | if opts.skip_confirmation: 26 | raise ValueError("RetryTxSender doesnt support skip confirmation") 27 | self.opts = opts 28 | self.blockhash_commitment = blockhash_commitment 29 | 30 | async def get_blockhash(self) -> Hash: 31 | return ( 32 | await self.connection.get_latest_blockhash(self.blockhash_commitment) 33 | ).value.blockhash 34 | 35 | async def fetch_latest_blockhash(self) -> Hash: 36 | return await self.get_blockhash() 37 | 38 | async def get_versioned_tx( 39 | self, 40 | ixs: Sequence[Instruction], 41 | payer: Keypair, 42 | lookup_tables: Sequence[AddressLookupTableAccount], 43 | 
additional_signers: Optional[Sequence[Keypair]] = None, 44 | ) -> VersionedTransaction: 45 | latest_blockhash = await self.fetch_latest_blockhash() 46 | 47 | msg = MessageV0.try_compile( 48 | payer.pubkey(), ixs, lookup_tables, latest_blockhash 49 | ) 50 | 51 | signers = [payer] 52 | if additional_signers is not None: 53 | [signers.append(signer) for signer in additional_signers] 54 | 55 | return VersionedTransaction(msg, signers) 56 | 57 | async def send(self, tx: VersionedTransaction) -> TxSigAndSlot: 58 | raw = bytes(tx) 59 | 60 | body = self.connection._send_raw_transaction_body(raw, self.opts) 61 | resp = await self.connection._provider.make_request(body, SendTransactionResp) 62 | 63 | if not isinstance(resp, SendTransactionResp): 64 | raise Exception(f"Unexpected response from send transaction: {resp}") 65 | 66 | sig = resp.value 67 | 68 | sig_status = await self.connection.confirm_transaction( 69 | sig, self.opts.preflight_commitment 70 | ) 71 | slot = sig_status.context.slot 72 | 73 | return TxSigAndSlot(sig, slot) 74 | 75 | async def send_no_confirm(self, tx: VersionedTransaction) -> TxSigAndSlot: 76 | raw = bytes(tx) 77 | 78 | body = self.connection._send_raw_transaction_body(raw, self.opts) 79 | resp = await self.connection._provider.make_request(body, SendTransactionResp) 80 | sig = resp.value 81 | 82 | return TxSigAndSlot(sig, 0) 83 | -------------------------------------------------------------------------------- /src/driftpy/tx/types.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Optional, Sequence 4 | 5 | from solders.address_lookup_table_account import AddressLookupTableAccount 6 | from solders.instruction import Instruction 7 | from solders.keypair import Keypair 8 | from solders.signature import Signature 9 | from solders.transaction import VersionedTransaction 10 | 11 | 12 | @dataclass 13 | class TxSigAndSlot: 14 
| tx_sig: Signature 15 | slot: int 16 | 17 | 18 | class TxSender: 19 | @abstractmethod 20 | async def get_versioned_tx( 21 | self, 22 | ixs: Sequence[Instruction], 23 | payer: Keypair, 24 | lookup_tables: Sequence[AddressLookupTableAccount], 25 | additional_signers: Optional[Sequence[Keypair]], 26 | ) -> VersionedTransaction: 27 | pass 28 | 29 | @abstractmethod 30 | async def send(self, tx: VersionedTransaction) -> TxSigAndSlot: 31 | pass 32 | -------------------------------------------------------------------------------- /src/driftpy/user_map/polling_sub.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from driftpy.user_map.types import Subscription 3 | 4 | 5 | class PollingSubscription(Subscription): 6 | def __init__(self, user_map, frequency: float, skip_initial_load: bool = False): 7 | from driftpy.user_map.user_map import UserMap 8 | 9 | self.user_map: UserMap = user_map 10 | self.frequency = frequency 11 | self.skip_initial_load = skip_initial_load 12 | self.timer_task = None 13 | 14 | async def subscribe(self): 15 | if self.timer_task is not None: 16 | return 17 | 18 | if not self.skip_initial_load: 19 | await self.user_map.sync() 20 | 21 | self.timer_task = asyncio.create_task(self._polling_loop()) 22 | 23 | async def _polling_loop(self): 24 | if self.frequency == 0: 25 | # We don't want to start this loop if the frequency is zero. 
26 | return 27 | while True: 28 | await asyncio.sleep(self.frequency) 29 | await self.user_map.sync() 30 | 31 | async def unsubscribe(self): 32 | if self.timer_task is not None: 33 | self.timer_task.cancel() 34 | self.timer_task = None 35 | -------------------------------------------------------------------------------- /src/driftpy/user_map/types.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from driftpy.drift_user import DriftUser 3 | from solders.pubkey import Pubkey 4 | from typing import Optional 5 | from enum import Enum 6 | 7 | 8 | class UserMapInterface(ABC): 9 | @abstractmethod 10 | async def subscribe(self) -> None: 11 | pass 12 | 13 | @abstractmethod 14 | async def unsubscribe(self) -> None: 15 | pass 16 | 17 | @abstractmethod 18 | async def add_pubkey(self, user_account_public_key: Pubkey) -> None: 19 | pass 20 | 21 | @abstractmethod 22 | def has(self, key: str) -> bool: 23 | pass 24 | 25 | @abstractmethod 26 | def get(self, key: str) -> Optional[DriftUser]: 27 | pass 28 | 29 | @abstractmethod 30 | async def must_get(self, key: str) -> DriftUser: 31 | pass 32 | 33 | @abstractmethod 34 | def get_user_authority(self, key: str) -> Optional[Pubkey]: 35 | pass 36 | 37 | @abstractmethod 38 | def values(self): 39 | pass 40 | 41 | 42 | class Subscription(ABC): 43 | pass 44 | 45 | 46 | class ConfigType(Enum): 47 | CACHED = "cached" 48 | WEBSOCKET = "websocket" 49 | POLLING = "polling" 50 | -------------------------------------------------------------------------------- /src/driftpy/user_map/user_map_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Literal, Optional, Union 3 | 4 | from solana.rpc.async_api import AsyncClient 5 | from solana.rpc.commitment import Commitment 6 | 7 | from driftpy.drift_client import DriftClient 8 | 9 | 10 | @dataclass 11 | class 
UserAccountFilterCriteria: 12 | # only return users that have open orders 13 | has_open_orders: bool 14 | 15 | 16 | @dataclass 17 | class PollingConfig: 18 | frequency: int 19 | commitment: Optional[Commitment] = None 20 | 21 | 22 | @dataclass 23 | class WebsocketConfig: 24 | resub_timeout_ms: Optional[int] = None 25 | commitment: Optional[Commitment] = None 26 | 27 | 28 | @dataclass 29 | class UserMapConfig: 30 | drift_client: DriftClient 31 | subscription_config: Union[PollingConfig, WebsocketConfig] 32 | # connection object to use specifically for the UserMap. 33 | # If None, will use the drift_client's connection 34 | connection: Optional[AsyncClient] = None 35 | # True to skip the initial load of user_accounts via gPA 36 | skip_initial_load: Optional[bool] = False 37 | # True to include idle users when loading. 38 | # Defaults to false to decrease # of accounts subscribed to 39 | include_idle: Optional[bool] = None 40 | 41 | 42 | @dataclass 43 | class SyncConfig: 44 | type: Literal["default", "paginated"] 45 | chunk_size: Optional[int] = None 46 | concurrency_limit: Optional[int] = None 47 | 48 | 49 | @dataclass 50 | class UserStatsMapConfig: 51 | drift_client: DriftClient 52 | connection: Optional[AsyncClient] = None 53 | sync_config: Optional[SyncConfig] = None 54 | -------------------------------------------------------------------------------- /src/driftpy/user_map/websocket_sub.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, TypeVar 2 | from driftpy.accounts.ws.program_account_subscriber import ( 3 | WebSocketProgramAccountSubscriber, 4 | ) 5 | from driftpy.memcmp import get_user_filter, get_non_idle_user_filter 6 | from driftpy.accounts.types import UpdateCallback, WebsocketProgramAccountOptions 7 | 8 | T = TypeVar("T") 9 | 10 | 11 | class WebsocketSubscription: 12 | def __init__( 13 | self, 14 | user_map, 15 | commitment, 16 | on_update: UpdateCallback, 17 | 
skip_initial_load: bool = False, 18 | resub_timeout_ms: int = None, 19 | include_idle: bool = False, 20 | decode: Optional[Callable[[bytes], T]] = None, 21 | ): 22 | from driftpy.user_map.user_map import UserMap 23 | 24 | self.user_map: UserMap = user_map 25 | self.commitment = commitment 26 | self.on_update = on_update 27 | self.skip_initial_load = skip_initial_load 28 | self.resub_timeout_ms = resub_timeout_ms 29 | self.include_idle = include_idle 30 | self.subscriber = None 31 | self.decode = decode 32 | 33 | async def subscribe(self): 34 | if not self.subscriber: 35 | filters = (get_user_filter(),) 36 | if not self.include_idle: 37 | filters += (get_non_idle_user_filter(),) 38 | options = WebsocketProgramAccountOptions(filters, self.commitment, "base64") 39 | self.subscriber = WebSocketProgramAccountSubscriber( 40 | "UserMap", 41 | self.user_map.drift_client.program, 42 | options, 43 | self.on_update, 44 | self.decode, 45 | ) 46 | 47 | await self.subscriber.subscribe() 48 | 49 | if not self.skip_initial_load: 50 | await self.user_map.sync() 51 | 52 | async def unsubscribe(self): 53 | if not self.subscriber: 54 | return 55 | await self.subscriber.unsubscribe() 56 | self.subscriber = None 57 | -------------------------------------------------------------------------------- /src/driftpy/vaults/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some simple helpers for interacting with the vaults program. 
3 | 4 | For a complete vaults SDK, please see https://github.com/drift-labs/drift-vaults 5 | """ 6 | 7 | from driftpy.vaults.helpers import ( 8 | fetch_all_vault_depositors, 9 | filter_vault_depositors, 10 | get_all_vaults, 11 | get_depositor_info, 12 | get_vault_by_name, 13 | get_vault_depositors, 14 | get_vault_stats, 15 | get_vaults_program, 16 | ) 17 | from driftpy.vaults.vault_client import VaultClient 18 | 19 | __all__ = [ 20 | "get_vaults_program", 21 | "get_all_vaults", 22 | "get_vault_by_name", 23 | "get_vault_depositors", 24 | "get_vault_stats", 25 | "get_depositor_info", 26 | "filter_vault_depositors", 27 | "fetch_all_vault_depositors", 28 | "VaultClient", 29 | ] 30 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/tests/__init__.py -------------------------------------------------------------------------------- /tests/ci/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/tests/ci/__init__.py -------------------------------------------------------------------------------- /tests/decode/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/tests/decode/__init__.py -------------------------------------------------------------------------------- /tests/decode/decode_stat.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import time 3 | 4 | from pathlib import Path 5 | from pytest import fixture, mark 6 | 7 | from anchorpy import Idl, Program 8 | 9 | from solders.pubkey import Pubkey 10 | 11 | import driftpy 12 | 
from driftpy.types import UserStatsAccount
from driftpy.decode.user_stat import decode_user_stat

from tests.decode.stat_decode_strings import stats

from operator import attrgetter

# UserStatsAccount fields that must decode identically with anchorpy and the
# hand-written decoder. Dotted names are resolved with operator.attrgetter.
_COMPARED_FIELDS = (
    "fees.total_fee_paid",
    "fees.total_fee_rebate",
    "fees.total_token_discount",
    "fees.total_referee_discount",
    "fees.total_referrer_reward",
    "fees.current_epoch_referrer_reward",
    "next_epoch_ts",
    "maker_volume30d",
    "taker_volume30d",
    "filler_volume30d",
    "last_maker_volume30d_ts",
    "last_taker_volume30d_ts",
    "last_filler_volume30d_ts",
    "if_staked_quote_asset_amount",
    "number_of_sub_accounts",
    "number_of_sub_accounts_created",
    "is_referrer",
    "disable_update_perp_bid_ask_twap",
)


@fixture(scope="session")
def program() -> Program:
    """Build an anchorpy ``Program`` from the Drift IDL bundled with driftpy."""
    idl_path = Path(driftpy.__path__[0]) / "idl" / "drift.json"
    # read_text() opens and closes the file itself; no separate open() needed.
    idl = Idl.from_json(idl_path.read_text())
    return Program(
        idl,
        Pubkey.from_string("dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH"),
    )


@mark.asyncio
async def test_user_stat_decode(program: Program):
    """Check the custom decoder against anchorpy on every sample and print timings."""
    total_anchor_ms = 0.0
    total_custom_ms = 0.0

    for index, stat in enumerate(stats):
        stat_bytes = base64.b64decode(stat)
        anchor_ms, custom_ms = user_stats_decode(program, stat_bytes, index)
        total_anchor_ms += anchor_ms
        total_custom_ms += custom_ms

    print("Total anchor time:", total_anchor_ms)
    print("Total custom time:", total_custom_ms)


def user_stats_decode(program: Program, buffer: bytes, index: int):
    """Decode ``buffer`` with both decoders, assert they agree, and return timings.

    Returns:
        ``(anchor_ms, custom_ms)``: elapsed milliseconds spent in each decoder,
        measured with ``time.perf_counter`` so sub-millisecond decodes are not
        truncated to zero.
    """
    print("Benchmarking user stats decode: ", index)

    start = time.perf_counter()
    anchor_user_stats: UserStatsAccount = program.coder.accounts.decode(buffer)
    anchor_ms = (time.perf_counter() - start) * 1_000

    start = time.perf_counter()
    custom_user_stats = decode_user_stat(buffer)
    custom_ms = (time.perf_counter() - start) * 1_000

    # Public keys are compared by their string form.
    assert str(anchor_user_stats.authority) == str(custom_user_stats.authority)
    assert str(anchor_user_stats.referrer) == str(custom_user_stats.referrer)
    for field in _COMPARED_FIELDS:
        get = attrgetter(field)
        # The field name is the assertion message, so a mismatch says which one.
        assert get(anchor_user_stats) == get(custom_user_stats), field

    return (anchor_ms, custom_ms)


# ------------------------------------------------------------------------------
# /tests/decode/dlob_test_helpers.py
# ------------------------------------------------------------------------------
from solders.pubkey import Pubkey

from typing import Optional
from driftpy.dlob.dlob import DLOB

from driftpy.types import (
    MarketType,
    Order,
    OrderStatus,
    OrderTriggerCondition,
    OrderType,
    PositionDirection,
)


def insert_order_to_dlob(
    dlob: DLOB,
    user_account: Pubkey,
    order_type: OrderType,
    market_type: MarketType,
    order_id: int,
    market_index: int,
    price: int,
    base_asset_amount: int,
    direction: PositionDirection,
    auction_start_price: int,
    auction_end_price: int,
    slot: Optional[int] = None,
    max_ts=0,
    oracle_price_offset=0,
    post_only=False,
    auction_duration=10,
):
    """Build an open test ``Order`` and insert it into ``dlob`` for ``user_account``.

    ``slot`` defaults to 1 when omitted; an explicit 0 is kept.
    """
    slot = slot if slot is not None else 1
    order = Order(
        slot,
        price,
        base_asset_amount,
        0,
        0,
        0,
        auction_start_price,
        auction_end_price,
        max_ts,
        oracle_price_offset,
        order_id,
        market_index,
        OrderStatus.Open(),
        order_type,
        market_type,
        0,
        PositionDirection.Long(),
        direction,
        False,
        post_only,
        False,
        OrderTriggerCondition.Above(),
        auction_duration,
        [0, 0, 0],
    )
    dlob.insert_order(order, user_account, slot)


def insert_trigger_order_to_dlob(
    dlob: DLOB,
    user_account: Pubkey,
    order_type: OrderType,
    market_type: MarketType,
    order_id: int,
    market_index: int,
    price: int,
    base_asset_amount: int,
    direction: PositionDirection,
    trigger_price: int,
    trigger_condition: OrderTriggerCondition,
    auction_start_price: int,
    auction_end_price: int,
    slot: Optional[int] = None,
    max_ts=0,
    oracle_price_offset=0,
):
    """Build an open test ``Order`` carrying ``trigger_price``/``trigger_condition``
    and insert it into ``dlob`` for ``user_account``.

    ``slot`` defaults to 1 when omitted; an explicit 0 is kept, matching
    ``insert_order_to_dlob``.
    """
    # `slot or 1` would silently replace an explicit slot=0 with 1.
    slot = slot if slot is not None else 1
    order = Order(
        slot,
        price,
        base_asset_amount,
        0,
        0,
        trigger_price,
        auction_start_price,
        auction_end_price,
        max_ts,
        oracle_price_offset,
        order_id,
        market_index,
        OrderStatus.Open(),
        order_type,
        market_type,
        0,
        PositionDirection.Long(),
        direction,
        False,
        False,
        True,
        trigger_condition,
        # NOTE(review): insert_order_to_dlob passes auction_duration in this
        # position; here max_ts is reused -- confirm that is intended.
        max_ts,
        [0, 0, 0],
    )
    dlob.insert_order(order, user_account, slot)


# ------------------------------------------------------------------------------
# /tests/dlob/__init__.py
# https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/tests/dlob/__init__.py
# ------------------------------------------------------------------------------
# /tests/integration/events_parser.py
# ------------------------------------------------------------------------------
from pytest import mark

from solana.rpc.async_api import AsyncClient
from solders.signature import Signature
from anchorpy import Wallet

from driftpy.events.event_subscriber import EventSubscriber
from driftpy.drift_client import DriftClient

_TX_SIG = (
    "3JRzMVquzXXmbV7cPMxiMRQp25scFFkYpsntWjMJQv3i4sMZyVXAGi6X2vAHNkqH1mkNtLvpp4oT6iorzZgLkYNY"
)


@mark.asyncio
async def test_events_parser():
    """Fetch a known mainnet transaction and parse Drift events from its logs."""
    connection = AsyncClient("https://api.mainnet-beta.solana.com")
    try:
        drift_client = DriftClient(connection, Wallet.dummy())

        event_subscriber = EventSubscriber(drift_client.connection, drift_client.program)

        signature = Signature.from_string(_TX_SIG)
        tx = await connection.get_transaction(
            signature,
            max_supported_transaction_version=0,
        )
        assert tx.value is not None, "transaction not found"
        logs = tx.value.transaction.meta.log_messages

        events = event_subscriber.parse_events_from_logs(
            signature,
            tx.value.slot,
            logs,
        )

        print(events)
    finally:
        # Release the RPC client's HTTP session even if the test fails.
        await connection.close()
# ------------------------------------------------------------------------------
# /tests/integration/swb_on_demand.py
# ------------------------------------------------------------------------------
from pytest import mark, approx

from solana.rpc.async_api import AsyncClient

from solders.pubkey import Pubkey

from driftpy.accounts.oracle import (
    decode_oracle,
    get_oracle_decode_fn,
    get_oracle_price_data_and_slot,
    SWB_ON_DEMAND_CODER,
)
from driftpy.accounts.types import OracleSource
from driftpy.constants.numeric_constants import SWB_PRECISION


@mark.asyncio
async def test_swb_on_demand():
    """Cross-check Switchboard On-Demand decoding paths against a mainnet oracle."""
    oracle = Pubkey.from_string("EZLBfnznMYKjFmaWYMEdhwnkiQF1WiP9jjTY6M8HpmGE")
    oracle_source = OracleSource.SwitchboardOnDemand()
    connection = AsyncClient("https://api.mainnet-beta.solana.com")

    # Only these two RPC calls need the connection; close it even if they fail.
    try:
        oracle_fetched = await get_oracle_price_data_and_slot(
            connection, oracle, oracle_source
        )
        raw = (await connection.get_account_info(oracle)).value.data
    finally:
        await connection.close()

    oracle_unstructured = SWB_ON_DEMAND_CODER.accounts.decode(raw).result

    decode = get_oracle_decode_fn(oracle_source)
    oracle_decode_fn = decode(raw)

    oracle_decode_oracle = decode_oracle(raw, oracle_source)

    # these two should be identical
    assert oracle_decode_oracle.price == oracle_decode_fn.price
    assert oracle_decode_oracle.slot == oracle_decode_fn.slot
    assert oracle_decode_oracle.confidence == oracle_decode_fn.confidence
    assert oracle_decode_oracle.twap == oracle_decode_fn.twap
    assert oracle_decode_oracle.twap_confidence == oracle_decode_fn.twap_confidence
    assert (
        oracle_decode_oracle.has_sufficient_number_of_data_points
        == oracle_decode_fn.has_sufficient_number_of_data_points
    )

    # potential slight diff from slot drift
    assert oracle_fetched.data.price == approx(oracle_decode_oracle.price)
    assert oracle_fetched.data.slot == approx(oracle_decode_oracle.slot)
    assert oracle_fetched.data.confidence == approx(oracle_decode_oracle.confidence)
    assert oracle_fetched.data.twap == oracle_decode_oracle.twap
    assert oracle_fetched.data.twap_confidence == oracle_decode_oracle.twap_confidence
    assert (
        oracle_fetched.data.has_sufficient_number_of_data_points
        == oracle_decode_oracle.has_sufficient_number_of_data_points
    )

    assert oracle_unstructured.value == approx(
        oracle_fetched.data.price * SWB_PRECISION
    )
    assert oracle_unstructured.slot == approx(oracle_fetched.data.slot)
    assert oracle_unstructured.range == approx(
        oracle_fetched.data.confidence * SWB_PRECISION
    )

    assert (oracle_unstructured.value // SWB_PRECISION) == oracle_decode_oracle.price
    assert oracle_unstructured.slot == oracle_decode_oracle.slot
    assert (
        oracle_unstructured.range // SWB_PRECISION
    ) == oracle_decode_oracle.confidence


# ------------------------------------------------------------------------------
# /tests/math/__init__.py
# https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/tests/math/__init__.py
# ------------------------------------------------------------------------------
# /tests/math/auction.py
# ------------------------------------------------------------------------------
from pytest import mark

from driftpy.math.auction import derive_oracle_auction_params
from driftpy.types import PositionDirection
from driftpy.constants.numeric_constants import PRICE_PRECISION


@mark.asyncio
async def test_drive_oracle_auction_params():
    """Check derive_oracle_auction_params for long and short orders.

    Each result is indexed [0], [1], [2] and compared against offsets from
    ``oracle_price``, including the degenerate case where every price equals
    the oracle price.
    """
    oracle_price = 100 * PRICE_PRECISION
    auction_start_price = 90 * PRICE_PRECISION
    auction_end_price = 110 * PRICE_PRECISION
    limit_price = 120 * PRICE_PRECISION

    oracle_order_params = derive_oracle_auction_params(
        PositionDirection.Long(),
        oracle_price,
        auction_start_price,
        auction_end_price,
        limit_price,
    )

    assert oracle_order_params[0] == -10 * PRICE_PRECISION
    assert oracle_order_params[1] == 10 * PRICE_PRECISION
    assert oracle_order_params[2] == 20 * PRICE_PRECISION

    oracle_order_params = derive_oracle_auction_params(
        PositionDirection.Long(), oracle_price, oracle_price, oracle_price, oracle_price
    )

    assert oracle_order_params[0] == 0
    assert oracle_order_params[1] == 0
    assert oracle_order_params[2] == 1

    oracle_price = 100 * PRICE_PRECISION
    auction_start_price = 110 * PRICE_PRECISION
    auction_end_price = 90 * PRICE_PRECISION
    limit_price = 80 * PRICE_PRECISION

    oracle_order_params = derive_oracle_auction_params(
        PositionDirection.Short(),
        oracle_price,
        auction_start_price,
        auction_end_price,
        limit_price,
    )

    assert oracle_order_params[0] == 10 * PRICE_PRECISION
    assert oracle_order_params[1] == -10 * PRICE_PRECISION
    assert oracle_order_params[2] == -20 * PRICE_PRECISION

    oracle_order_params = derive_oracle_auction_params(
        PositionDirection.Short(),
        oracle_price,
        oracle_price,
        oracle_price,
        oracle_price,
    )

    assert oracle_order_params[0] == 0
    assert oracle_order_params[1] == 0
    assert oracle_order_params[2] == -1


# ------------------------------------------------------------------------------
# /tests/math/insurance.py
# ------------------------------------------------------------------------------
from pytest import mark

from driftpy.math.utils import time_remaining_until_update


@mark.asyncio
async def test_time_remaining_updates():
    """Check time_remaining_until_update against fixed timestamps and a 1h period."""
    now = 1_683_576_852
    last_update = 1_683_576_000
    period = 3_600

    tr = time_remaining_until_update(now, last_update, period)
    assert tr == 2_748

    tr = time_remaining_until_update(now, last_update - period, period)
    assert tr == 0

    too_late = last_update - ((period // 3) + 1)
    tr = time_remaining_until_update(too_late + 1, too_late, period)
    assert tr == 4_800

    tr = time_remaining_until_update(now, last_update + 1, period)
    assert tr == 2_748

    tr = time_remaining_until_update(now, last_update - 1, period)
    assert tr == 2_748


# ------------------------------------------------------------------------------
# /update_idl.sh  (shell script; reproduced as comments)
# ------------------------------------------------------------------------------
# git submodule update --remote --merge --recursive &&
# cd protocol-v2/ &&
# anchor build &&
# cp target/idl/drift.json ../src/driftpy/idl/drift.json