├── .bumpversion.cfg
├── .gitattributes
├── .github
└── workflows
│ ├── docs.yml
│ ├── main.yml
│ ├── release.yml
│ └── tests.yml
├── .gitignore
├── .gitmodules
├── .vscode
└── settings.json
├── CHANGELOG.md
├── README.md
├── docs
├── accounts.md
├── addresses.md
├── clearing_house.md
├── clearing_house_user.md
├── css
│ ├── custom.css
│ └── mkdocstrings.css
├── img
│ └── drift.png
└── index.md
├── examples
├── authority_user-account_lookup.py
├── check_pyth_lazer.py
├── deposit_and_withdraw.py
├── driftpy-enhanced-usermap.py
├── driftpy-marketmap-details.py
├── dump_pickle.py
├── example_grpc_log_provider.py
├── fetch_all_markets.py
├── floating_maker.py
├── get_all_users_pnl.py
├── get_borrows.py
├── get_deposits.py
├── get_funding.py
├── get_vault_metrics.py
├── grpc_auction.py
├── grpc_client.py
├── high_leverage_users.py
├── if_stake.py
├── jit_maker_grpc.py
├── limit_order_grid.py
├── minimal_maker.py
├── oracle_maker.py
├── orders_subcribe.py
├── place_and_take.py
├── pmm_users.py
├── protected_order.py
├── readme.md
├── send_indicative_quotes.py
├── settle_pnl.py
├── signed_msg_subscribe.py
├── spot_market_trade.py
├── start_lp.py
├── swift_maker.py
├── swift_take_place_and_make.py
├── swift_taker.py
├── swift_taker_delegate.py
├── view.py
└── ws_simple.py
├── mkdocs.yml
├── poetry.lock
├── poetry.toml
├── pyproject.toml
├── requirements.txt
├── scripts
├── bump.py
├── ci.sh
├── decode.sh
├── dlob.sh
├── generate_constants.py
└── math.sh
├── setup.sh
├── src
└── driftpy
│ ├── __init__.py
│ ├── account_subscription_config.py
│ ├── accounts
│ ├── __init__.py
│ ├── bulk_account_loader.py
│ ├── cache
│ │ ├── __init__.py
│ │ ├── drift_client.py
│ │ └── user.py
│ ├── demo
│ │ ├── __init__.py
│ │ ├── drift_client.py
│ │ └── user.py
│ ├── get_accounts.py
│ ├── grpc
│ │ ├── account_subscriber.py
│ │ ├── drift_client.py
│ │ ├── geyser_codegen
│ │ │ ├── README.md
│ │ │ ├── geyser_pb2.py
│ │ │ ├── geyser_pb2.pyi
│ │ │ ├── geyser_pb2_grpc.py
│ │ │ ├── solana_storage_pb2.py
│ │ │ ├── solana_storage_pb2.pyi
│ │ │ ├── solana_storage_pb2_grpc.py
│ │ │ └── subscribe_geyser.py
│ │ ├── program_account_subscriber.py
│ │ ├── user.py
│ │ └── user_stats.py
│ ├── oracle.py
│ ├── polling
│ │ ├── __init__.py
│ │ ├── drift_client.py
│ │ └── user.py
│ ├── types.py
│ └── ws
│ │ ├── __init__.py
│ │ ├── account_subscriber.py
│ │ ├── drift_client.py
│ │ ├── program_account_subscriber.py
│ │ ├── user.py
│ │ └── user_stats.py
│ ├── address_lookup_table.py
│ ├── addresses.py
│ ├── admin.py
│ ├── auction_subscriber
│ ├── auction_subscriber.py
│ ├── grpc_auction_subscriber.py
│ └── types.py
│ ├── constants
│ ├── __init__.py
│ ├── config.py
│ ├── numeric_constants.py
│ ├── perp_markets.py
│ └── spot_markets.py
│ ├── decode
│ ├── pull_oracle.py
│ ├── signed_msg_order.py
│ ├── user.py
│ ├── user_stat.py
│ └── utils.py
│ ├── dlob
│ ├── client_types.py
│ ├── dlob.py
│ ├── dlob_helpers.py
│ ├── dlob_node.py
│ ├── dlob_subscriber.py
│ ├── node_list.py
│ └── orderbook_levels.py
│ ├── drift_client.py
│ ├── drift_user.py
│ ├── drift_user_stats.py
│ ├── events
│ ├── __init__.py
│ ├── event_list.py
│ ├── event_subscriber.py
│ ├── fetch_logs.py
│ ├── grpc_log_provider.py
│ ├── parse.py
│ ├── polling_log_provider.py
│ ├── sort.py
│ ├── tx_event_cache.py
│ ├── types.py
│ └── websocket_log_provider.py
│ ├── idl
│ ├── __init__.py
│ ├── drift.json
│ ├── drift_vaults.json
│ ├── pyth.json
│ ├── sequence_enforcer.json
│ ├── switchboard.json
│ ├── switchboard_on_demand.json
│ └── token_faucet.json
│ ├── indicative_quotes
│ ├── __init__.py
│ └── indicative_quotes_sender.py
│ ├── keypair.py
│ ├── market_map
│ ├── grpc_market_map.py
│ ├── grpc_sub.py
│ ├── market_map.py
│ ├── market_map_config.py
│ └── websocket_sub.py
│ ├── math
│ ├── amm.py
│ ├── auction.py
│ ├── conversion.py
│ ├── exchange_status.py
│ ├── fuel.py
│ ├── funding.py
│ ├── margin.py
│ ├── market.py
│ ├── oracles.py
│ ├── orders.py
│ ├── perp_position.py
│ ├── repeg.py
│ ├── spot_balance.py
│ ├── spot_market.py
│ ├── spot_position.py
│ ├── user_status.py
│ └── utils.py
│ ├── memcmp.py
│ ├── name.py
│ ├── oracles
│ ├── oracle_id.py
│ └── strict_oracle_price.py
│ ├── pickle
│ └── vat.py
│ ├── priority_fees
│ └── priority_fee_subscriber.py
│ ├── py.typed
│ ├── setup
│ └── helpers.py
│ ├── slot
│ └── slot_subscriber.py
│ ├── swift
│ ├── create_verify_ix.py
│ ├── order_subscriber.py
│ └── util.py
│ ├── tx
│ ├── __init__.py
│ ├── fast_tx_sender.py
│ ├── jito_subscriber.py
│ ├── jito_tx_sender.py
│ ├── standard_tx_sender.py
│ └── types.py
│ ├── types.py
│ ├── user_map
│ ├── polling_sub.py
│ ├── types.py
│ ├── user_map.py
│ ├── user_map_config.py
│ ├── userstats_map.py
│ └── websocket_sub.py
│ └── vaults
│ ├── __init__.py
│ ├── helpers.py
│ └── vault_client.py
├── tests
├── __init__.py
├── ci
│ ├── __init__.py
│ ├── devnet.py
│ └── mainnet.py
├── decode
│ ├── __init__.py
│ ├── decode.py
│ ├── decode_stat.py
│ ├── decode_strings.py
│ ├── dlob_test_helpers.py
│ └── stat_decode_strings.py
├── dlob
│ ├── __init__.py
│ └── dlob.py
├── dlob_test_constants.py
├── integration
│ ├── events_parser.py
│ ├── liq.py
│ ├── oracle.py
│ ├── prelaunch.py
│ ├── swb_on_demand.py
│ └── test_oracle_diff_sources.py
└── math
│ ├── __init__.py
│ ├── amm.py
│ ├── auction.py
│ ├── funding.py
│ ├── helpers.py
│ ├── insurance.py
│ ├── spot.py
│ ├── spreads.py
│ └── user.py
└── update_idl.sh
/.bumpversion.cfg:
--------------------------------------------------------------------------------
1 | [bumpversion]
2 | current_version = 0.8.56
3 | commit = True
4 | tag = True
5 | tag_name = {new_version}
6 |
7 | [bumpversion:file:pyproject.toml]
8 | search = version = "{current_version}"
9 | replace = version = "{new_version}"
10 |
11 | [bumpversion:file:src/driftpy/__init__.py]
12 | search = __version__ = "{current_version}"
13 | replace = __version__ = "{new_version}"
14 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | drift-core/** linguist-vendored
2 |
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: Docs
2 | on:
3 | release:
4 | types: [published]
5 | jobs:
6 | docs:
7 | name: Deploy docs
8 | runs-on: ubuntu-latest
9 | steps:
10 | - name: Checkout
11 | uses: actions/checkout@v2.5.0
12 |
13 | - name: Set up Python
14 | uses: actions/setup-python@v4.3.0
15 | with:
16 | python-version: '3.10'
17 |
18 | - name: Install dependencies
19 | run: pip install mkdocs mkdocstrings-python mkdocs-material driftpy
20 |
21 | - name: Deploy docs
22 | run: mkdocs gh-deploy --force
23 |
--------------------------------------------------------------------------------
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 | on:
3 | push:
4 | branches:
5 | - master
6 | pull_request:
7 | branches: [master]
8 |
9 | defaults:
10 | run:
11 | shell: bash
12 | working-directory: .
13 |
14 | jobs:
15 | tests:
16 | runs-on: ubicloud
17 | steps:
18 | - uses: actions/checkout@v3
19 | - name: Set up Python 3.10
20 | uses: actions/setup-python@v4
21 | with:
22 | python-version: "3.10.10"
23 | - name: Install and configure Poetry
24 | uses: snok/install-poetry@v1.3.3
25 | with:
26 | version: 1.4.2
27 | virtualenvs-create: true
28 | virtualenvs-in-project: true
29 | installer-parallel: true
30 | - name: Install dependencies
31 | run: poetry install
32 | - name: Install pytest
33 | run: poetry run pip install pytest
34 | - name: Install ruff
35 | run: poetry run pip install ruff
36 | - name: Run tests
37 | env:
38 | MAINNET_RPC_ENDPOINT: ${{ secrets.MAINNET_RPC_ENDPOINT }}
39 | DEVNET_RPC_ENDPOINT: ${{ secrets.DEVNET_RPC_ENDPOINT }}
40 | # run: poetry run ruff format --check . && poetry run bash scripts/ci.sh
41 | run: poetry run bash scripts/ci.sh
42 |
43 | bump-version:
44 | runs-on: ubicloud
45 | needs: [tests]
46 | if: github.event_name == 'push' && github.ref == 'refs/heads/master'
47 | steps:
48 | - uses: actions/checkout@v3
49 | with:
50 | fetch-depth: 0
51 | - name: Set up Python 3.10
52 | uses: actions/setup-python@v4
53 | with:
54 | python-version: "3.10.10"
55 | - name: Run version bump script
56 | run: python scripts/bump.py
57 | - name: Commit changes
58 | run: |
59 | git config --local user.email "github-actions[bot]@users.noreply.github.com"
60 | git config --local user.name "github-actions[bot]"
61 | git add pyproject.toml src/driftpy/__init__.py .bumpversion.cfg
62 | git commit -m "Bump version [skip ci]"
63 | - name: Push changes
64 | uses: ad-m/github-push-action@master
65 | with:
66 | github_token: ${{ secrets.GITHUB_TOKEN }}
67 | branch: ${{ github.ref }}
68 |
69 | release:
70 | runs-on: ubicloud
71 | needs: [bump-version]
72 | if: github.event_name == 'push' && github.ref == 'refs/heads/master'
73 | steps:
74 | - name: Checkout
75 | uses: actions/checkout@v3
76 | with:
77 | fetch-depth: 0
78 | - name: Pull Latest Changes
79 | run: |
80 | git config --local user.email "github-actions[bot]@users.noreply.github.com"
81 | git config --local user.name "github-actions[bot]"
82 | git pull
83 | - name: Set up Python
84 | uses: actions/setup-python@v4
85 | with:
86 | python-version: "3.10.10"
87 | - name: Install and configure Poetry
88 | uses: snok/install-poetry@v1.3.3
89 | with:
90 | version: 1.4.2
91 | virtualenvs-create: true
92 | virtualenvs-in-project: true
93 | installer-parallel: true
94 | - name: Build package
95 | run: poetry build
96 | - name: Publish to PyPI
97 | run: poetry publish --username=__token__ --password=${{ secrets.PYPI_TOKEN }}
98 | - name: Get version
99 | id: get_version
100 | run: echo "VERSION=$(poetry version -s)" >> $GITHUB_OUTPUT
101 | - name: Create GitHub Release
102 | env:
103 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
104 | run: |
105 | gh release create v${{ steps.get_version.outputs.VERSION }} \
106 | --title "Release v${{ steps.get_version.outputs.VERSION }}" \
107 | --generate-notes \
108 | dist/*.whl dist/*.tar.gz
109 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 | on:
3 | release:
4 | types: [published]
5 | jobs:
6 | release:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - name: Checkout
10 | uses: actions/checkout@v2.5.0
11 |
12 | - name: Set up Python
13 | uses: actions/setup-python@v4.3.0
14 | with:
15 | python-version: '3.10.10'
16 | #----------------------------------------------
17 | # ----- install & configure poetry -----
18 | #----------------------------------------------
19 | - name: Install and configure Poetry
20 | uses: snok/install-poetry@v1.3.3
21 | with:
22 | version: 1.4.2
23 | virtualenvs-create: true
24 | virtualenvs-in-project: true
25 | installer-parallel: true
26 | - run: poetry build
27 | - run: poetry publish --username=__token__ --password=${{ secrets.PYPI_TOKEN }}
28 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 | on:
3 | push:
4 | branches: [master]
5 | pull_request:
6 | branches: [master]
7 |
8 | env:
9 | solana_verion: 1.14.7
10 | anchor_version: 0.26.0
11 |
12 | jobs:
13 | tests:
14 | runs-on: ubuntu-latest
15 | steps:
16 | - name: Checkout repo.
17 | uses: actions/checkout@v2
18 | with:
19 | submodules: 'recursive'
20 |
21 | - name: Cache Solana Tool Suite
22 | uses: actions/cache@v2
23 | id: cache-solana
24 | with:
25 | path: |
26 | ~/.cache/solana/
27 | ~/.local/share/solana/
28 | key: solana-${{ runner.os }}-v0000-${{ env.solana_verion }}
29 |
30 | - name: Install Rust toolchain
31 | uses: actions-rs/toolchain@v1
32 | with:
33 | profile: minimal
34 | toolchain: nightly
35 | override: true
36 |
37 | - name: Install Solana
38 | if: steps.cache-solana.outputs.cache-hit != 'true'
39 | run: sh -c "$(curl -sSfL https://release.solana.com/v${{ env.solana_verion }}/install)"
40 |
41 | - name: Add Solana to path
42 | run: echo "/home/runner/.local/share/solana/install/active_release/bin" >> $GITHUB_PATH
43 |
44 | - uses: actions/setup-node@v2
45 | with:
46 | node-version: '17'
47 |
48 | - name: install Anchor CLI
49 | run: npm install -g @project-serum/anchor-cli
50 |
51 | - name: Generate local keypair
52 | run: yes | solana-keygen new
53 |
54 | - name: Set up Python
55 | uses: actions/setup-python@v1
56 | with:
57 | python-version: 3.9
58 |
59 | #----------------------------------------------
60 | # ----- install & configure poetry -----
61 | #----------------------------------------------
62 | - name: Install and configure Poetry
63 | uses: snok/install-poetry@v1
64 | with:
65 | version: 1.1.10
66 | virtualenvs-create: true
67 | virtualenvs-in-project: true
68 | installer-parallel: true
69 | #----------------------------------------------
70 | # load cached venv if cache exists
71 | #----------------------------------------------
72 | - name: Load cached venv
73 | id: cached-poetry-dependencies
74 | uses: actions/cache@v2
75 | with:
76 | path: .venv
77 | key: venv-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
78 | #----------------------------------------------
79 | # install dependencies if cache does not exist (todo)
80 | #----------------------------------------------
81 | - name: Install dependencies
82 | # if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
83 | run: poetry install --no-interaction --no-root
84 | #----------------------------------------------
85 | # install your root project
86 | #----------------------------------------------
87 | - name: Install library
88 | run: poetry install --no-interaction
89 | #----------------------------------------------
90 | # install nox-poetry
91 | #----------------------------------------------
92 | - name: Install nox-poetry
93 | run: pip install nox-poetry
94 | #----------------------------------------------
95 | # run linters
96 | #----------------------------------------------
97 | - name: Run linters
98 | run: make lint
99 | #----------------------------------------------
100 | # run test suite
101 | #----------------------------------------------
102 | - name: Run tests
103 | run: make test
104 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .tmp
2 | .DS_Store
3 | node.txt
4 | keypairs/
5 | test-ledger/
6 |
7 | node_modules/
8 | yarn.lock
9 | package-lock.json
10 | package.json
11 | .prettierrc
12 |
13 | # Byte-compiled / optimized / DLL files
14 | __pycache__/
15 | *.py[cod]
16 | *$py.class
17 |
18 | # C extensions
19 | *.so
20 |
21 | # Distribution / packaging
22 | .Python
23 | docs/.DS_Store
24 | build/
25 | develop-eggs/
26 | dist/
27 | downloads/
28 | eggs/
29 | .eggs/
30 | lib/
31 | lib64/
32 | parts/
33 | sdist/
34 | var/
35 | wheels/
36 | share/python-wheels/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 | MANIFEST
41 |
42 | # PyInstaller
43 | # Usually these files are written by a python script from a template
44 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
45 | *.manifest
46 | *.spec
47 |
48 | # Installer logs
49 | pip-log.txt
50 | pip-delete-this-directory.txt
51 |
52 | # Unit test / coverage reports
53 | htmlcov/
54 | .tox/
55 | .nox/
56 | .coverage
57 | .coverage.*
58 | .cache
59 | nosetests.xml
60 | coverage.xml
61 | *.cover
62 | *.py,cover
63 | .hypothesis/
64 | .pytest_cache/
65 | cover/
66 |
67 | # Translations
68 | *.mo
69 | *.pot
70 |
71 | # Django stuff:
72 | *.log
73 | local_settings.py
74 | db.sqlite3
75 | db.sqlite3-journal
76 |
77 | # Flask stuff:
78 | instance/
79 | .webassets-cache
80 |
81 | # Scrapy stuff:
82 | .scrapy
83 |
84 | # Sphinx documentation
85 | docs/_build/
86 |
87 | # PyBuilder
88 | .pybuilder/
89 | target/
90 |
91 | # Jupyter Notebook
92 | .ipynb_checkpoints
93 |
94 | # IPython
95 | profile_default/
96 | ipython_config.py
97 |
98 | # pyenv
99 | # For a library or package, you might want to ignore these files since the code is
100 | # intended to run in multiple environments; otherwise, check them in:
101 | # .python-version
102 |
103 | # pipenv
104 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
105 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
106 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
107 | # install all needed dependencies.
108 | #Pipfile.lock
109 |
110 | # poetry
111 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
112 | # This is especially recommended for binary packages to ensure reproducibility, and is more
113 | # commonly ignored for libraries.
114 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
115 | #poetry.lock
116 |
117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
118 | __pypackages__/
119 |
120 | # Celery stuff
121 | celerybeat-schedule
122 | celerybeat.pid
123 |
124 | # SageMath parsed files
125 | *.sage.py
126 |
127 | # Environments
128 | .env
129 | .venv
130 | env/
131 | venv/
132 | ENV/
133 | env.bak/
134 | venv.bak/
135 |
136 | # Spyder project settings
137 | .spyderproject
138 | .spyproject
139 |
140 | # Rope project settings
141 | .ropeproject
142 |
143 | # mkdocs documentation
144 | /site
145 |
146 | # mypy
147 | .mypy_cache/
148 | .dmypy.json
149 | dmypy.json
150 |
151 | # Pyre type checker
152 | .pyre/
153 |
154 | # pytype static type analyzer
155 | .pytype/
156 |
157 | # Cython debug symbols
158 | cython_debug/
159 |
160 | # PyCharm
161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163 | # and can be added to the global gitignore or merged into this file. For a more nuclear
164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165 | .idea/
166 |
167 | scratch
168 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "protocol-v2"]
2 | path = protocol-v2
3 | url = https://github.com/drift-labs/protocol-v2.git
4 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.formatting.provider": "black",
3 | "python.analysis.extraPaths": [
4 | "${workspaceFolder}/src/",
5 | "./venv/bin/python"
6 | ]
7 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DriftPy
2 |
3 |
4 |

5 |
6 |
7 | DriftPy is the Python client for the [Drift](https://www.drift.trade/) protocol.
8 | It allows you to trade and fetch data from Drift using Python.
9 |
10 | **[Read the full SDK documentation here!](https://drift-labs.github.io/v2-teacher/)**
11 |
12 | ## Installation
13 |
14 | ```
15 | pip install driftpy
16 | ```
17 |
18 | Note: requires Python >= 3.10.
19 |
20 |
21 | ## SDK Examples
22 |
23 | - `examples/` folder includes more examples of how to use the SDK including how to provide liquidity/become an lp, stake in the insurance fund, etc.
24 |
25 |
26 | ## Note on using QuickNode
27 |
28 | If you are using the QuickNode free plan, you *must* use `AccountSubscriptionConfig("demo")`, and you can only subscribe to 1 perp market and 1 spot market at a time.
29 |
30 | Non-QuickNode free RPCs (including the public mainnet-beta url) can use `cached` as well.
31 |
32 | Example setup for `AccountSubscriptionConfig("demo")`:
33 |
34 | ```python
35 | # This example will listen to perp markets 0 & 1 and spot market 0
36 | # If you are listening to any perp markets, you must listen to spot market 0 or the SDK will break
37 |
38 | perp_markets = [0, 1]
39 | spot_market_oracle_infos, perp_market_oracle_infos, spot_market_indexes = get_markets_and_oracles(perp_markets = perp_markets)
40 |
41 | oracle_infos = spot_market_oracle_infos + perp_market_oracle_infos
42 |
43 | drift_client = DriftClient(
44 | connection,
45 | wallet,
46 | "mainnet",
47 | perp_market_indexes = perp_markets,
48 | spot_market_indexes = spot_market_indexes,
49 | oracle_infos = oracle_infos,
50 | account_subscription = AccountSubscriptionConfig("demo"),
51 | )
52 | await drift_client.subscribe()
53 | ```
54 | If you intend to use `AccountSubscriptionConfig("demo")`, you *must* call `get_markets_and_oracles` to get the information you need.
55 |
56 | `get_markets_and_oracles` will return all the necessary `OracleInfo`s and `market_indexes` in order to use the SDK.
57 |
58 | # Development
59 |
60 | ## Setting Up Dev Env
61 |
62 | `bash setup.sh`
63 |
64 |
65 | Ensure correct python version (using pyenv is recommended):
66 | ```bash
67 | pyenv install 3.10.11
68 | pyenv global 3.10.11
69 | poetry env use $(pyenv which python)
70 | ```
71 |
72 | Install dependencies:
73 | ```bash
74 | poetry install
75 | ```
76 |
77 | To run tests, first ensure you have set up the RPC url, then run `pytest`:
78 | ```bash
79 | export MAINNET_RPC_ENDPOINT=""
80 | export DEVNET_RPC_ENDPOINT="https://api.devnet.solana.com" # or your own RPC
81 |
82 | poetry run pytest -v -s -x tests/ci/*.py
83 | poetry run pytest -v -s tests/math/*.py
84 | ```
85 |
--------------------------------------------------------------------------------
/docs/accounts.md:
--------------------------------------------------------------------------------
1 | # Accounts
2 |
3 | These functions are used to retrieve specific on-chain accounts (State, PerpMarket, SpotMarket, etc.)
4 |
5 | ## Example
6 |
7 | ```python
8 | drift_client= DriftClient.from_config(config, provider)
9 |
10 | # get sol market info
11 | sol_market_index = 0
12 | sol_market = await get_perp_market_account(drift_client.program, sol_market_index)
13 | print(
14 | sol_market.amm.sqrt_k,
15 | sol_market.amm.base_asset_amount_long,
16 | sol_market.amm.base_asset_amount_short,
17 | )
18 |
19 | # get usdc spot market info
20 | usdc_spot_market_index = 0
21 | usdc_market = await get_spot_market_account(drift_client.program, usdc_spot_market_index)
22 | print(
23 | usdc_market.market_index,
24 | usdc_market.deposit_balance,
25 | usdc_market.borrow_balance,
26 | )
27 | ```
28 |
29 | :::driftpy.accounts
30 |
--------------------------------------------------------------------------------
/docs/addresses.md:
--------------------------------------------------------------------------------
1 | # Addresses
2 |
3 | These functions are used to derive the on-chain addresses of accounts (e.g. the public key of the SOL market account)
4 |
5 | :::driftpy.addresses
--------------------------------------------------------------------------------
/docs/clearing_house.md:
--------------------------------------------------------------------------------
1 | # Drift Client
2 |
3 | This object is used to interact with the protocol (deposit, withdraw, trade, lp, etc.)
4 |
5 | ## Example
6 |
7 | ```python
8 | drift_client = DriftClient.from_config(config,provider)
9 | # open a 10 SOL long position
10 | sig = await drift_client.open_position(
11 | PositionDirection.LONG(), # long
12 | int(10 * BASE_PRECISION), # 10 in base precision
13 | 0, # sol market index
14 | )
15 |
16 | # mint 100 LP shares on the SOL market
17 | await drift_client.add_liquidity(
18 | int(100 * AMM_RESERVE_PRECISION),
19 | 0,
20 | )
21 | ```
22 |
23 | ## Configuration
24 |
25 | Use the `JUPITER_URL` environment variable to set the endpoint URL for the Jupiter V6 Swap API. This allows you to switch between self-hosted, paid-hosted, or other public API endpoints such as [jupiterapi.com](https://www.jupiterapi.com/) for higher rate limits and reduced latency. For more details, see the official [self-hosted](https://station.jup.ag/docs/apis/self-hosted) and [paid-hosted](https://station.jup.ag/docs/apis/self-hosted#paid-hosted-apis) documentation.
26 |
27 | :::driftpy.drift_client
28 |
--------------------------------------------------------------------------------
/docs/clearing_house_user.md:
--------------------------------------------------------------------------------
1 | # User
2 |
3 | This object is used to fetch data from the protocol and view user metrics (leverage, free collateral, etc.)
4 |
5 | ## Example
6 |
7 | ```python
8 | drift_client = DriftClient.from_config(config, provider)
9 | drift_user = User(drift_client)
10 |
11 | # inspect user's leverage
12 | leverage = await drift_user.get_leverage()
13 | print('current leverage:', leverage / 10_000)
14 |
15 | # you can also inspect other accounts information using the (authority=) flag
16 | bigz_acc = User(drift_client, authority=PublicKey('bigZ'))
17 | leverage = await bigz_acc.get_leverage()
18 | print('bigZs leverage:', leverage / 10_000)
19 |
20 | # user calls can be expensive on the rpc so we can cache them
21 | drift_user = User(drift_client, use_cache=True)
22 | await drift_user.set_cache()
23 |
24 | # works without any rpc calls (uses the cached data)
25 | upnl = await drift_user.get_unrealized_pnl(with_funding=True)
26 | print('upnl:', upnl)
27 | ```
28 |
29 | :::driftpy.drift_user
30 |
--------------------------------------------------------------------------------
/docs/css/custom.css:
--------------------------------------------------------------------------------
1 | a.external-link::after {
2 | /* \00A0 is a non-breaking space
3 | to make the mark be on the same line as the link
4 | */
5 | content: "\00A0[↪]";
6 | }
7 |
8 | a.internal-link::after {
9 | /* \00A0 is a non-breaking space
10 | to make the mark be on the same line as the link
11 | */
12 | content: "\00A0↪";
13 | }
14 |
--------------------------------------------------------------------------------
/docs/css/mkdocstrings.css:
--------------------------------------------------------------------------------
1 | /* Indentation. */
2 | div.doc-contents:not(.first) {
3 | padding-left: 25px;
4 | border-left: 4px solid rgba(230, 230, 230);
5 | margin-bottom: 80px;
6 | }
7 |
--------------------------------------------------------------------------------
/docs/img/drift.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/docs/img/drift.png
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # Drift-v2 Python SDK
2 |
3 |
4 |

5 |
6 |
7 | DriftPy is the Python SDK for [Drift-v2](https://www.drift.trade/) on Solana.
8 | It allows you to trade and fetch data from Drift using Python.
9 |
10 | ## Installation
11 |
12 | ```
13 | pip install driftpy
14 | ```
15 |
16 | Note: requires Python >= 3.10.
17 |
18 | ## Key Components
19 |
20 | - `DriftClient` / `drift_client.py`: Used to interact with the protocol (deposit, withdraw, trade, lp, etc.)
21 | - `DriftUser` / `drift_user.py`: Used to fetch data from the protocol and view user metrics (leverage, free collateral, etc.)
22 | - `accounts.py`: Used to retrieve specific on-chain accounts (State, PerpMarket, SpotMarket, etc.)
23 | - `addresses.py`: Used to derive on-chain addresses of the accounts (publickey of the sol-market)
24 |
25 | ## Example
26 |
27 | ```python
28 |
29 | from solana.keypair import Keypair
30 | from driftpy.drift_client import DriftClient
31 | from driftpy.drift_user import DriftUser
32 | from driftpy.constants.numeric_constants import BASE_PRECISION, AMM_RESERVE_PRECISION
33 |
34 | from anchorpy import Provider, Wallet
35 | from solana.rpc.async_api import AsyncClient
36 |
37 | # load keypair from file
38 | KEYPATH = '../your-keypair-secret.json'
39 | with open(KEYPATH, 'r') as f:
40 | secret = json.load(f)
41 | kp = Keypair.from_secret_key(bytes(secret))
42 |
43 | # create clearing house for mainnet
44 | ENV = 'mainnet'
45 | config = configs[ENV]
46 | wallet = Wallet(kp)
47 | connection = AsyncClient(config.default_http)
48 | provider = Provider(connection, wallet)
49 |
50 | drift_client = DriftClient.from_config(config, provider)
51 | drift_user = DriftUser(drift_client)
52 |
53 | # open a 10 SOL long position
54 | sig = await drift_client.open_position(
55 | PositionDirection.LONG(), # long
56 | int(10 * BASE_PRECISION), # 10 in base precision
57 | 0, # sol market index
58 | )
59 |
60 | # mint 100 LP shares on the SOL market
61 | await drift_client.add_liquidity(
62 | int(100 * AMM_RESERVE_PRECISION),
63 | 0,
64 | )
65 |
66 | # inspect user's leverage
67 | leverage = await drift_user.get_leverage()
68 | print('current leverage:', leverage / 10_000)
69 |
70 | # you can also inspect another account's information using the (user_public_key=) argument
71 | bigz_acc = DriftUser(drift_client, user_public_key=PublicKey('bigZ'))
72 | leverage = await bigz_acc.get_leverage()
73 | print('bigZs leverage:', leverage / 10_000)
74 |
75 | # clearing house user calls can be expensive on the rpc so we can cache them
76 | drift_user = DriftUser(drift_client, use_cache=True)
77 | await drift_user.set_cache()
78 |
79 | # works without any rpc calls (uses the cached data)
80 | upnl = await drift_user.get_unrealized_pnl(with_funding=True)
81 | print('upnl:', upnl)
82 | ```
83 |
--------------------------------------------------------------------------------
/examples/check_pyth_lazer.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 |
4 | from anchorpy.provider import Wallet
5 | from dotenv import load_dotenv
6 | from solana.rpc.async_api import AsyncClient
7 |
8 | from driftpy.drift_client import DriftClient
9 | from driftpy.types import is_variant
10 |
11 |
12 | async def main():
13 | load_dotenv()
14 | url = os.getenv("DEVNET_RPC_ENDPOINT", "https://api.devnet.solana.com")
15 | connection = AsyncClient(url)
16 | print("RPC URL:", url)
17 |
18 | print("Checking devnet constants")
19 | drift_client = DriftClient(
20 | connection,
21 | Wallet.dummy(),
22 | env="devnet",
23 | )
24 |
25 | print("Subscribing to Drift Client")
26 | await drift_client.subscribe()
27 | received_perp_markets = sorted(
28 | drift_client.get_perp_market_accounts(), key=lambda market: market.market_index
29 | )
30 | for market in received_perp_markets:
31 | oracle_data = drift_client.get_user().get_oracle_data_for_perp_market(
32 | market.market_index
33 | )
34 | if oracle_data and (
35 | is_variant(market.amm.oracle_source, "PythLazer")
36 | or is_variant(market.amm.oracle_source, "PythLazer1K")
37 | or is_variant(market.amm.oracle_source, "PythLazer1M")
38 | ):
39 | print(
40 | market.market_index,
41 | market.amm.oracle,
42 | bytes(market.name).decode("utf-8").strip(),
43 | market.amm.oracle_source,
44 | oracle_data.price / 10**6,
45 | )
46 |
47 | print("Subscribed to Drift Client")
48 | await drift_client.unsubscribe()
49 | await connection.close()
50 |
51 |
52 | if __name__ == "__main__":
53 | asyncio.run(main())
54 |
--------------------------------------------------------------------------------
/examples/deposit_and_withdraw.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 |
4 | from anchorpy.provider import Provider, Wallet
5 | from dotenv import load_dotenv
6 | from solana.rpc.async_api import AsyncClient
7 |
8 | from driftpy.drift_client import DriftClient
9 | from driftpy.keypair import load_keypair
10 | from driftpy.types import TxParams
11 |
12 | LAMPORTS_PER_SOL = 10**9
13 |
14 |
15 | load_dotenv()
16 |
17 |
async def make_spot_trade():
    """Deposit 0.4 units into spot market 1 and withdraw the same amount again.

    Despite its name this example performs a deposit followed by a
    withdrawal. The amount is ``0.4 * LAMPORTS_PER_SOL`` and the wallet's own
    public key is passed as the token account (presumably native SOL; verify
    against the market configuration).

    Reads ``RPC_TRITON`` and ``PRIVATE_KEY`` from the environment.

    Raises:
        ValueError: if ``PRIVATE_KEY`` is not set.
    """
    rpc = os.environ.get("RPC_TRITON")
    secret = os.environ.get("PRIVATE_KEY")
    if not secret:
        # Fail with a clear message instead of an opaque error inside load_keypair.
        raise ValueError("PRIVATE_KEY must be set in the environment")

    kp = load_keypair(secret)
    wallet = Wallet(kp)
    print(f"Using wallet: {wallet.public_key}")

    connection = AsyncClient(rpc)
    provider = Provider(connection, wallet)
    drift_client = DriftClient(
        provider.connection,
        provider.wallet,
        env="mainnet",
        tx_params=TxParams(compute_units_price=85_000, compute_units=1_400_000),
    )

    try:
        await drift_client.subscribe()
        try:
            print("Drift client subscribed")

            amount = int(0.4 * LAMPORTS_PER_SOL)

            print("Depositing 0.4")
            await drift_client.deposit(
                amount=amount,
                spot_market_index=1,
                user_token_account=drift_client.wallet.public_key,
            )
            print("Deposited 0.4")

            print("Withdrawing 0.4")
            await drift_client.withdraw(
                amount=amount,
                market_index=1,
                user_token_account=drift_client.wallet.public_key,
            )
            print("Withdrew 0.4")
        finally:
            # Drop the subscription even when a transaction fails.
            await drift_client.unsubscribe()
    finally:
        await connection.close()


if __name__ == "__main__":
    asyncio.run(make_spot_trade())
58 |
--------------------------------------------------------------------------------
/examples/example_grpc_log_provider.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 |
4 | from dotenv import load_dotenv
5 | from solana.rpc.commitment import Commitment
6 |
7 | from driftpy.events.grpc_log_provider import GrpcLogProvider
8 | from driftpy.events.types import GrpcLogProviderConfig
9 |
10 | load_dotenv()
11 |
12 |
async def simple_log_callback(signature: str, slot: int, logs: list[str]):
    """Announce a received transaction log batch on stdout.

    Only the signature and slot are printed; ``logs`` is accepted to satisfy
    the callback signature but is not inspected.
    """
    header = f"Received logs for tx: {signature} in slot: {slot}"
    print(header, "---", sep="\n")
16 |
17 |
async def main():
    """Subscribe to gRPC logs for one user account and listen for 60 seconds.

    ``GRPC_ENDPOINT`` and ``GRPC_AUTH_TOKEN`` are read from the environment;
    a ``ValueError`` is raised when either is missing. The provider is always
    unsubscribed before the function returns.
    """
    endpoint = os.getenv("GRPC_ENDPOINT")
    auth_token = os.getenv("GRPC_AUTH_TOKEN")
    if not (endpoint and auth_token):
        raise ValueError(
            "GRPC_ENDPOINT and GRPC_AUTH_TOKEN must be set in the environment"
        )

    account_filter = "BrRpSaQ6hFDw8darPCyP9Sw7sjydMFQqB4ECAotXSEci"
    print(f"Attempting to connect to gRPC endpoint: {endpoint}")

    log_provider = GrpcLogProvider(
        grpc_config=GrpcLogProviderConfig(endpoint=endpoint, token=auth_token),
        commitment=Commitment("confirmed"),
        user_account_to_filter=account_filter,
    )

    print("Subscribing to logs...")
    await log_provider.subscribe(simple_log_callback)

    run_duration = 60
    print(f"Listening for logs for {run_duration} seconds...")

    try:
        for elapsed in range(run_duration):
            # Stop early when the provider reports that it is no longer subscribed.
            if not log_provider.is_subscribed():
                print("Subscription lost unexpectedly.")
                break
            await asyncio.sleep(1)
            # Progress line every tenth second, skipping the very first tick.
            if elapsed > 0 and elapsed % 10 == 0:
                print(f"Still subscribed after {elapsed} seconds...")
    except asyncio.CancelledError:
        print("Run cancelled.")
    finally:
        print("Unsubscribing...")
        await log_provider.unsubscribe()
        print("Example finished.")


if __name__ == "__main__":
    # Ctrl-C ends the example with a short message.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("Exiting due to KeyboardInterrupt...")
70 |
--------------------------------------------------------------------------------
/examples/fetch_all_markets.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 |
4 | from anchorpy import Provider, Wallet
5 | from dotenv import load_dotenv
6 | from solana.rpc.async_api import AsyncClient
7 | from solders.keypair import Keypair
8 |
9 | from driftpy.constants.numeric_constants import MARGIN_PRECISION
10 | from driftpy.drift_client import AccountSubscriptionConfig, DriftClient
11 |
12 | load_dotenv()
13 |
14 |
async def get_all_market_names():
    """Fetch all perp and spot markets and return their decoded names.

    Each perp market is printed together with its maximum leverage. The
    returned list holds the perp market names (ordered by market index)
    followed by the spot market names (ordered by market index) without the
    first spot market.

    Reads the RPC endpoint from ``MAINNET_RPC_ENDPOINT``.
    """

    def decode_name(entry):
        # Market names are stored as byte arrays; strip() removes the padding whitespace.
        return bytes(entry.account.name).decode("utf-8").strip()

    def by_market_index(entry):
        return entry.account.market_index

    rpc = os.environ.get("MAINNET_RPC_ENDPOINT")
    kp = Keypair()  # random wallet
    wallet = Wallet(kp)
    connection = AsyncClient(rpc)
    provider = Provider(connection, wallet)
    drift_client = DriftClient(
        provider.connection,
        provider.wallet,
        "mainnet",
        account_subscription=AccountSubscriptionConfig("cached"),
    )
    await drift_client.subscribe()

    all_perp_markets = await drift_client.program.account["PerpMarket"].all()
    sorted_perp_markets = sorted(all_perp_markets, key=by_market_index)
    result_perp = [decode_name(entry) for entry in sorted_perp_markets]

    print("Perp Markets:")
    for entry, market_name in zip(sorted_perp_markets, result_perp):
        # Use the account's own market index: the position in the sorted list
        # only matches it when the indexes are contiguous and start at zero.
        max_leverage = get_perp_market_max_leverage(
            drift_client, entry.account.market_index
        )
        print(f"{market_name} - {max_leverage}")

    # The original code referenced ``result_spot`` without ever defining it,
    # which raised NameError; the spot markets are now fetched the same way
    # as the perp markets.
    all_spot_markets = await drift_client.program.account["SpotMarket"].all()
    sorted_spot_markets = sorted(all_spot_markets, key=by_market_index)
    result_spot = [decode_name(entry) for entry in sorted_spot_markets]

    # The first spot market is left out, as in the original expression
    # (presumably the quote asset; verify if it should be listed).
    result = result_perp + result_spot[1:]
    return result
42 |
43 |
def get_perp_market_max_leverage(drift_client, market_index: int) -> float:
    """Return the larger of a perp market's standard and high-leverage limits.

    Each limit is ``MARGIN_PRECISION`` divided by the corresponding initial
    margin ratio. The high-leverage limit counts as 0 when the market's
    ``high_leverage_margin_ratio_initial`` is not positive.
    """
    perp_market = drift_client.get_perp_market_account(market_index)
    standard_limit = MARGIN_PRECISION / perp_market.margin_ratio_initial

    high_leverage_ratio = perp_market.high_leverage_margin_ratio_initial
    if high_leverage_ratio > 0:
        high_leverage_limit = MARGIN_PRECISION / high_leverage_ratio
    else:
        high_leverage_limit = 0

    return max(standard_limit, high_leverage_limit)
55 |
56 |
if __name__ == "__main__":
    # asyncio.run creates the event loop and closes it again afterwards; the
    # previous manually created loop was never closed.
    answer = asyncio.run(get_all_market_names())
    print(answer)
61 |
--------------------------------------------------------------------------------
/examples/get_all_users_pnl.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import csv
3 | import logging
4 | import os
5 | from datetime import datetime
6 |
7 | from anchorpy.provider import Provider, Wallet
8 | from dotenv import load_dotenv
9 | from solana.rpc.async_api import AsyncClient
10 | from solders.pubkey import Pubkey
11 | from tqdm import tqdm
12 |
13 | from driftpy.account_subscription_config import AccountSubscriptionConfig
14 | from driftpy.drift_client import DriftClient
15 | from driftpy.drift_user import DriftUser
16 | from driftpy.keypair import load_keypair
17 | from driftpy.user_map.user_map import UserMap
18 | from driftpy.user_map.user_map_config import PollingConfig, UserMapConfig
19 |
20 | logging.basicConfig(level=logging.INFO)
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | load_dotenv()
25 |
26 |
async def main():
    """Write a CSV report with PnL and collateral figures for every user in the UserMap.

    Reads ``MAINNET_RPC_ENDPOINT`` and ``PRIVATE_KEY`` from the environment.
    The report is written to ``pnl_report_<timestamp>.csv`` in the current
    directory. A user whose figures cannot be computed is logged and skipped.

    NOTE(review): values are written as returned by the SDK without any
    precision scaling; confirm the intended units.
    """
    rpc = os.environ.get("MAINNET_RPC_ENDPOINT")
    private_key = os.environ.get("PRIVATE_KEY")
    kp = load_keypair(private_key)
    wallet = Wallet(kp)
    connection = AsyncClient(rpc)
    provider = Provider(connection, wallet)
    drift_client = DriftClient(
        provider.connection,
        provider.wallet,
        "mainnet",
        account_subscription=AccountSubscriptionConfig("cached"),
    )
    await drift_client.subscribe()

    usermap_config = UserMapConfig(drift_client, PollingConfig(frequency=2))
    usermap = UserMap(usermap_config)
    await usermap.subscribe()
    # Snapshot the entries up front (the original copied the keys for the same
    # reason) and reuse the user objects the map already holds, rather than
    # building a new, never-subscribed DriftUser for every key.
    users_snapshot = list(usermap.user_map.items())

    # Setup CSV output
    filename = f"pnl_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
    with open(filename, "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(
            [
                "User",
                "Authority",
                "Realized PnL",
                "Unrealized PnL",
                "Total PnL",
                "Total Collateral",
            ]
        )

        # Calculate PnL
        for user_pubkey, user in tqdm(users_snapshot):
            try:
                user_account = user.get_user_account()
                authority = str(user_account.authority)
                realized_pnl = user_account.settled_perp_pnl
                unrealized_pnl = user.get_unrealized_pnl(with_funding=True)
                total_pnl = realized_pnl + unrealized_pnl
                collateral = user.get_total_collateral()

                writer.writerow(
                    [
                        str(user_pubkey),
                        authority,
                        f"{realized_pnl:.2f}",
                        f"{unrealized_pnl:.2f}",
                        f"{total_pnl:.2f}",
                        f"{collateral:.2f}",
                    ]
                )
            except Exception as e:
                # Best effort: one failing user must not abort the whole report.
                logger.error(f"Error calculating PnL for {user_pubkey}: {e}")

    logger.info(f"CSV report generated: {filename}")


if __name__ == "__main__":
    asyncio.run(main())
96 |
--------------------------------------------------------------------------------
/examples/get_borrows.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 |
4 | from anchorpy.provider import Wallet
5 | from dotenv import load_dotenv
6 | from solana.rpc.async_api import AsyncClient
7 |
8 | from driftpy.accounts.get_accounts import get_spot_market_account
9 | from driftpy.drift_client import DriftClient
10 | from driftpy.math.spot_balance import (
11 | calculate_borrow_rate,
12 | calculate_deposit_rate,
13 | calculate_interest_rate,
14 | calculate_utilization,
15 | )
16 | from driftpy.math.spot_market import get_token_amount
17 | from driftpy.types import SpotBalanceType
18 |
19 |
async def main():
    """Print deposit/borrow totals and rate figures for spot market 0.

    The RPC endpoint is taken from ``RPC_URL``. Token amounts are divided by
    ``10**market.decimals`` and the rate figures by 10000 before printing.
    """
    load_dotenv()
    connection = AsyncClient(os.getenv("RPC_URL"))
    dc = DriftClient(connection, Wallet.dummy(), "mainnet")

    market = await get_spot_market_account(dc.program, 0)  # USDC
    if market is None:
        raise Exception("No market found")

    deposited = get_token_amount(
        market.deposit_balance,
        market,
        SpotBalanceType.Deposit(),  # type: ignore
    )
    borrowed = get_token_amount(
        market.borrow_balance,
        market,
        SpotBalanceType.Borrow(),  # type: ignore
    )
    print(f"token_deposit_amount: {(deposited/10**market.decimals):,.2f}")
    print(f"token_borrow_amount: {(borrowed/10**market.decimals):,.2f}")

    # All four figures are computed before any of them is printed.
    figures = (
        ("borrow_rate", calculate_borrow_rate(market)),
        ("deposit_rate", calculate_deposit_rate(market)),
        ("utilization", calculate_utilization(market)),
        ("interest_rate", calculate_interest_rate(market)),
    )

    precision = 10000
    for label, value in figures:
        print(f"{label}: {value/precision:.2f}%")


if __name__ == "__main__":
    asyncio.run(main())
56 |
--------------------------------------------------------------------------------
/examples/get_deposits.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 |
4 | from anchorpy.provider import Wallet
5 | from dotenv import load_dotenv
6 | from solana.rpc.async_api import AsyncClient
7 |
8 | from driftpy.account_subscription_config import AccountSubscriptionConfig
9 | from driftpy.constants.numeric_constants import SPOT_BALANCE_PRECISION
10 | from driftpy.drift_client import DriftClient
11 | from driftpy.market_map.market_map import MarketMap
12 | from driftpy.market_map.market_map_config import MarketMapConfig
13 | from driftpy.market_map.market_map_config import (
14 | WebsocketConfig as MarketMapWebsocketConfig,
15 | )
16 | from driftpy.pickle.vat import Vat
17 | from driftpy.types import MarketType, is_variant
18 | from driftpy.user_map.user_map import UserMap
19 | from driftpy.user_map.user_map_config import UserMapConfig, UserStatsMapConfig
20 | from driftpy.user_map.user_map_config import (
21 | WebsocketConfig as UserMapWebsocketConfig,
22 | )
23 | from driftpy.user_map.userstats_map import UserStatsMap
24 |
25 | load_dotenv()
26 |
27 |
def get_deposits_by_authority(vat: Vat, market_index: int):
    """Sum the deposits in one spot market per authority, largest first.

    Args:
        vat: Container whose ``users`` mapping holds the users to inspect.
        market_index: Spot market index whose positions are counted.

    Returns:
        ``{"deposits": [{"authority": ..., "balance": ...}, ...]}`` sorted by
        balance in descending order. Balances are ``scaled_balance`` divided
        by ``SPOT_BALANCE_PRECISION``.

    NOTE(review): the scaled balance is not multiplied by any cumulative
    interest factor here; confirm whether a token amount is expected.
    """
    deposits = {}

    for user in vat.users.values():
        user_account = user.get_user_account()
        for position in user_account.spot_positions:
            if position.market_index != market_index:
                continue
            if position.scaled_balance <= 0:
                continue
            if is_variant(position.balance_type, "Borrow"):
                continue

            # Key by the account's authority, not by the user account's own
            # public key, so that sub-accounts of one authority are summed
            # together as the function name and output label promise.
            authority = user_account.authority
            balance = position.scaled_balance / SPOT_BALANCE_PRECISION
            deposits[authority] = deposits.get(authority, 0) + balance

    ranked = sorted(deposits.items(), key=lambda item: item[1], reverse=True)
    return {
        "deposits": [
            {"authority": authority, "balance": balance}
            for authority, balance in ranked
        ]
    }
54 |
55 |
async def main():
    """Load markets and users, then return the deposits of spot market 0.

    Requires ``MAINNET_RPC_ENDPOINT``; raises ``ValueError`` when it is unset.
    Returns the dictionary produced by ``get_deposits_by_authority``.
    """
    rpc_url = os.getenv("MAINNET_RPC_ENDPOINT")
    if not rpc_url:
        raise ValueError("MAINNET_RPC_ENDPOINT is not set")

    dc = DriftClient(
        AsyncClient(rpc_url),
        Wallet.dummy(),
        "mainnet",
        account_subscription=AccountSubscriptionConfig("cached"),
    )

    def build_market_map(market_type):
        # Both market maps share everything except the market type.
        config = MarketMapConfig(
            dc.program,
            market_type,
            MarketMapWebsocketConfig(),
            dc.connection,
        )
        return MarketMap(config)

    perp_map = build_market_map(MarketType.Perp())  # type: ignore
    spot_map = build_market_map(MarketType.Spot())  # type: ignore
    user_map = UserMap(UserMapConfig(dc, UserMapWebsocketConfig()))
    stats_map = UserStatsMap(UserStatsMapConfig(dc))

    # Start all four subscriptions concurrently.
    subscriptions = [
        asyncio.create_task(source.subscribe())
        for source in (spot_map, perp_map, user_map, stats_map)
    ]
    await asyncio.gather(*subscriptions)
    print("Subscribed to Drift Client")

    await user_map.sync()
    print("Synced User Map")

    vat = Vat(
        dc,
        user_map,
        stats_map,
        spot_map,
        perp_map,
    )
    return get_deposits_by_authority(vat, 0)


if __name__ == "__main__":
    deposits = asyncio.run(main())
    print(deposits)
111 |
--------------------------------------------------------------------------------
/examples/get_funding.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 |
4 | from anchorpy.provider import Wallet
5 | from dotenv import load_dotenv
6 | from solana.rpc.async_api import AsyncClient
7 |
8 | from driftpy.accounts.get_accounts import get_perp_market_account
9 | from driftpy.accounts.oracle import get_oracle_price_data_and_slot
10 | from driftpy.constants import FUNDING_RATE_PRECISION, QUOTE_PRECISION
11 | from driftpy.drift_client import DriftClient
12 | from driftpy.math.funding import (
13 | calculate_live_mark_twap,
14 | calculate_long_short_funding_and_live_twaps,
15 | )
16 |
17 |
async def main():
    """Print funding-rate figures and TWAPs for perp market 0.

    The RPC endpoint is taken from ``RPC_URL``. Funding rates are divided by
    ``FUNDING_RATE_PRECISION`` and prices by ``QUOTE_PRECISION`` before
    printing.
    """
    import time

    load_dotenv()
    url = os.getenv("RPC_URL")
    connection = AsyncClient(url)
    dc = DriftClient(connection, Wallet.dummy(), "mainnet")
    await dc.subscribe()

    market = await get_perp_market_account(dc.program, 0)  # SOL-PERP
    if market is None:
        raise Exception("No market found")

    oracle_price = await get_oracle_price_data_and_slot(
        connection, market.amm.oracle, market.amm.oracle_source
    )
    oracle_price_data = oracle_price.data

    # Wall-clock Unix time. The event loop's ``time()`` is a monotonic clock
    # with an arbitrary reference point and cannot be compared with
    # timestamps stored on the market account.
    now = int(time.time())
    # NOTE(review): the last oracle price is used as the mark price here.
    mark_price = market.amm.historical_oracle_data.last_oracle_price

    (
        _mark_twap,
        oracle_twap,
        long_rate,
        _short_rate,
    ) = await calculate_long_short_funding_and_live_twaps(
        market, oracle_price_data, mark_price, now
    )

    precision = FUNDING_RATE_PRECISION
    print(f"Long Funding Rate: {long_rate/precision}%")
    print(
        f"Last 24h Avg Funding Rate: {market.amm.last24h_avg_funding_rate/precision}%"
    )
    print(f"Last Funding Rate: {market.amm.last_funding_rate/precision}%")
    print(f"Last Funding Rate Long: {market.amm.last_funding_rate_long/precision}%")
    print(f"Last Funding Rate Short: {market.amm.last_funding_rate_short/precision}%")

    print(f"Oracle Price TWAP: ${oracle_twap/QUOTE_PRECISION:.2f}")

    live_mark_twap = calculate_live_mark_twap(
        market, oracle_price_data, mark_price, now
    )
    print(f"Live Mark TWAP: ${live_mark_twap/QUOTE_PRECISION:.2f}")


if __name__ == "__main__":
    asyncio.run(main())
65 |
--------------------------------------------------------------------------------
/examples/get_vault_metrics.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 |
4 | import dotenv
5 | from solana.rpc.async_api import AsyncClient
6 |
7 | from driftpy.vaults import VaultClient
8 |
9 | dotenv.load_dotenv()
10 |
11 | connection = AsyncClient(os.getenv("RPC_TRITON"))
12 |
13 |
def format_vault_summary(analytics, top_n=5):
    """Render vault analytics as a multi-line, human-readable summary.

    The summary lists the three totals, then the first ``top_n`` entries of
    ``analytics["top_by_deposits"]`` and of ``analytics["top_by_users"]``,
    each numbered from 1.
    """
    lines = [
        f"Total Vaults: {analytics['total_vaults']}",
        f"Total Value Locked: {analytics['total_deposits']:,.2f}",
        f"Total Unique Depositors: {analytics['total_depositors']}",
        "\n----- Top Vaults by Deposits -----",
    ]
    lines.extend(
        f"{rank}. {vault['name']}: {vault['true_net_deposits']:,.2f}"
        for rank, vault in enumerate(analytics["top_by_deposits"][:top_n], 1)
    )

    lines.append("\n----- Top Vaults by Users -----")
    lines.extend(
        f"{rank}. {vault['name']}: {vault['depositor_count']} users"
        for rank, vault in enumerate(analytics["top_by_users"][:top_n], 1)
    )

    return "\n".join(lines)
30 |
31 |
async def main():
    """Print overall vault analytics and the top depositors of one named vault.

    Uses the module-level ``connection``. The per-vault section is skipped
    when no vault named ``SOL-NL-Neutral-Trade`` is found.
    """
    vault_client = await VaultClient(connection).initialize()
    print("Got vault client...")

    analytics = await vault_client.calculate_analytics()
    print("Got analytics...")

    all_depositors = await vault_client.get_all_depositors()
    print(f"Got {len(all_depositors)} depositors...")
    print("Top 5 depositors:")
    for rank, depositor in enumerate(all_depositors[:5], 1):
        print(f"{rank}. {depositor}")

    print(format_vault_summary(analytics))

    vault = await vault_client.get_vault_by_name("SOL-NL-Neutral-Trade")
    if not vault:
        return

    vault_depositors = await vault_client.get_vault_depositors_with_stats(
        vault.account.pubkey
    )
    print("\nTop 5 depositors:")
    for rank, dep in enumerate(vault_depositors[:5], 1):
        print(
            f"{rank}. {dep['pubkey']}: {dep['shares']}, {dep['share_percentage']:.2f}%"
        )


if __name__ == "__main__":
    asyncio.run(main())
59 |
--------------------------------------------------------------------------------
/examples/grpc_auction.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 |
4 | from anchorpy.provider import Provider, Wallet
5 | from dotenv import load_dotenv
6 | from solana.rpc.async_api import AsyncClient
7 | from solders.keypair import Keypair
8 |
9 | from driftpy.account_subscription_config import AccountSubscriptionConfig
10 | from driftpy.auction_subscriber.grpc_auction_subscriber import GrpcAuctionSubscriber
11 |
12 | # Auction subscriber imports
13 | from driftpy.auction_subscriber.types import GrpcAuctionSubscriberConfig
14 |
15 | # Drift imports
16 | from driftpy.drift_client import DriftClient
17 | from driftpy.types import GrpcConfig, UserAccount
18 |
19 | load_dotenv()
20 |
21 |
22 | # Example callback
def on_auction_account_update(user_account: UserAccount, pubkey, slot: int):
    """Print one line describing an updated in-auction User account.

    Registered as the auction subscriber's account-update handler; the slot,
    public key and account are written to stdout.
    """
    message = (
        f"[AUCTION UPDATE] Slot={slot} PublicKey={pubkey}, UserAccount={user_account}"
    )
    print(message)
30 |
31 |
async def main():
    """Watch in-auction User accounts over gRPC and print every update.

    Creates a gRPC-based Drift client, sets up the GrpcAuctionSubscriber and
    attaches ``on_auction_account_update`` as its callback, then idles until
    the task is cancelled (for example by Ctrl-C).

    Raises:
        ValueError: if any of ``RPC_FQDN``, ``X_TOKEN``, ``PRIVATE_KEY`` or
            ``RPC_TRITON`` is missing from the environment.
    """
    import inspect

    # 1) Load environment variables
    rpc_fqdn = os.environ.get("RPC_FQDN")  # e.g. "grpcs://my-geyser-endpoint.com:443"
    x_token = os.environ.get("X_TOKEN")  # your auth token
    private_key = os.environ.get("PRIVATE_KEY")  # base58-encoded, e.g. "42Ab..."
    rpc_url = os.environ.get("RPC_TRITON")  # normal Solana JSON-RPC for sending tx

    if not (rpc_fqdn and x_token and private_key and rpc_url):
        raise ValueError("RPC_FQDN, X_TOKEN, PRIVATE_KEY, and RPC_TRITON must be set")

    wallet = Wallet(Keypair.from_base58_string(private_key))
    connection = AsyncClient(rpc_url)
    provider = Provider(connection, wallet)

    drift_client = DriftClient(
        provider.connection,
        provider.wallet,
        "mainnet",
        account_subscription=AccountSubscriptionConfig(
            "grpc",
            grpc_config=GrpcConfig(
                endpoint=rpc_fqdn,
                token=x_token,
            ),
        ),
    )

    await drift_client.subscribe()

    auction_subscriber_config = GrpcAuctionSubscriberConfig(
        drift_client=drift_client,
        grpc_config=GrpcConfig(endpoint=rpc_fqdn, token=x_token),
        commitment=provider.connection.commitment,  # or "confirmed"
    )
    auction_subscriber = GrpcAuctionSubscriber(auction_subscriber_config)

    auction_subscriber.event_emitter.on_account_update += on_auction_account_update
    await auction_subscriber.subscribe()
    print("AuctionSubscriber is now watching for changes to in-auction User accounts.")

    try:
        # Idle forever; updates arrive through the callback.
        while True:
            await asyncio.sleep(10)
    finally:
        # Under asyncio.run a Ctrl-C reaches this coroutine as a cancellation,
        # not as KeyboardInterrupt, so the cleanup lives in ``finally`` to run
        # on every exit path.
        print("Unsubscribing from auction accounts...")
        pending = auction_subscriber.unsubscribe()
        # The original call was not awaited; await only if the SDK actually
        # returns an awaitable.
        if inspect.isawaitable(pending):
            await pending


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Expected way to stop the example; fall through to the final message.
        pass
    print("Done.")
88 |
--------------------------------------------------------------------------------
/examples/grpc_client.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 |
4 | from anchorpy.provider import Provider, Wallet
5 | from dotenv import load_dotenv
6 | from solana.rpc.async_api import AsyncClient
7 | from solana.rpc.commitment import Commitment
8 | from solders.keypair import Keypair
9 |
10 | from driftpy.drift_client import AccountSubscriptionConfig, DriftClient
11 | from driftpy.types import GrpcConfig
12 |
13 | load_dotenv()
14 |
15 | RED = "\033[91m"
16 | GREEN = "\033[92m"
17 | RESET = "\033[0m"
18 |
19 | CLEAR_SCREEN = "\033c"
20 |
21 |
async def watch_drift_markets():
    """Redraw the first 20 perp markets' last oracle prices once per second.

    Subscribes through gRPC and loops forever. A price is shown in green when
    it rose since the previous refresh, red when it fell, and uncoloured
    otherwise.

    Raises:
        ValueError: if ``RPC_FQDN``, ``X_TOKEN``, ``PRIVATE_KEY`` or
            ``RPC_TRITON`` is missing from the environment.
    """
    env = os.environ
    rpc_fqdn = env.get("RPC_FQDN")
    x_token = env.get("X_TOKEN")
    private_key = env.get("PRIVATE_KEY")
    rpc_url = env.get("RPC_TRITON")

    if not all((rpc_fqdn, x_token, private_key, rpc_url)):
        raise ValueError("RPC_FQDN, X_TOKEN, PRIVATE_KEY, and RPC_TRITON must be set")

    provider = Provider(
        AsyncClient(rpc_url),
        Wallet(Keypair.from_base58_string(private_key)),
    )
    grpc_config = GrpcConfig(
        endpoint=rpc_fqdn,
        token=x_token,
        commitment=Commitment("confirmed"),
    )
    drift_client = DriftClient(
        provider.connection,
        provider.wallet,
        "mainnet",
        account_subscription=AccountSubscriptionConfig(
            "grpc",
            grpc_config=grpc_config,
        ),
    )

    await drift_client.subscribe()
    print("Subscribed via gRPC. Listening for market updates...")

    # Last printed price per market index, used to pick the colour.
    previous_prices = {}

    while True:
        print(CLEAR_SCREEN, end="")

        perp_markets = drift_client.get_perp_market_accounts()

        if perp_markets:
            print("Drift Perp Markets (gRPC subscription)\n")
            ordered = sorted(perp_markets, key=lambda m: m.market_index)
            for market in ordered[:20]:
                market_index = market.market_index
                last_price = market.amm.historical_oracle_data.last_oracle_price / 1e6

                old_price = previous_prices.get(market_index)
                if old_price is None or last_price == old_price:
                    color = RESET
                elif last_price > old_price:
                    color = GREEN
                else:
                    color = RED

                print(
                    f"Market Index: {market_index} | "
                    f"Price: {color}${last_price:.4f}{RESET}"
                )

                previous_prices[market_index] = last_price
        else:
            print(f"{RED}No perp markets found (yet){RESET}")

        await asyncio.sleep(1)


if __name__ == "__main__":
    asyncio.run(watch_drift_markets())
91 |
--------------------------------------------------------------------------------
/examples/high_leverage_users.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 |
4 | from anchorpy import Wallet
5 | from dotenv import load_dotenv
6 | from solana.rpc.async_api import AsyncClient
7 |
8 | from driftpy.drift_client import DriftClient
9 | from driftpy.user_map.user_map import UserMap
10 | from driftpy.user_map.user_map_config import UserMapConfig
11 | from driftpy.user_map.user_map_config import (
12 | WebsocketConfig as UserMapWebsocketConfig,
13 | )
14 |
15 |
async def main():
    """Collect every non-idle user that is in high leverage mode.

    Returns a pair ``(users, keys)`` of parallel lists: the matching user
    objects and their keys in the user map.
    """
    load_dotenv()
    connection = AsyncClient(os.getenv("RPC_TRITON"))
    dc = DriftClient(
        connection,
        Wallet.dummy(),
        "mainnet",
    )
    await dc.subscribe()

    users = UserMap(UserMapConfig(dc, UserMapWebsocketConfig(), include_idle=False))
    await users.subscribe()

    matches = [
        (key, candidate)
        for key, candidate in users.user_map.items()
        if candidate.is_high_leverage_mode()
    ]
    high_leverage_users = [candidate for _, candidate in matches]
    keys = [key for key, _ in matches]

    return high_leverage_users, keys


if __name__ == "__main__":
    high_leverage_users, keys = asyncio.run(main())
    keys.sort()
    print(f"Number of high leverage users: {len(keys)}")
    # Optional: write the keys to a file.
    # with open("high_leverage_users.txt", "w") as f:
    #     for key in keys:
    #         f.write(f"{key}\n")
46 |
--------------------------------------------------------------------------------
/examples/pmm_users.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 |
4 | from anchorpy import Wallet
5 | from dotenv import load_dotenv
6 | from solana.rpc.async_api import AsyncClient
7 |
8 | from driftpy.addresses import get_protected_maker_mode_config_public_key
9 | from driftpy.drift_client import DriftClient
10 | from driftpy.math.user_status import is_user_protected_maker
11 | from driftpy.user_map.user_map import UserMap
12 | from driftpy.user_map.user_map_config import UserMapConfig
13 | from driftpy.user_map.user_map_config import (
14 | WebsocketConfig as UserMapWebsocketConfig,
15 | )
16 |
17 |
async def main():
    """Collect every user, idle ones included, flagged as a protected maker.

    Prints the protected-maker-mode config public key, then returns a pair
    ``(users, keys)`` of parallel lists for the matching users.
    """
    load_dotenv()
    connection = AsyncClient(os.getenv("RPC_TRITON"))
    dc = DriftClient(
        connection,
        Wallet.dummy(),
        "mainnet",
    )
    await dc.subscribe()

    users = UserMap(UserMapConfig(dc, UserMapWebsocketConfig(), include_idle=True))
    await users.subscribe()

    print(get_protected_maker_mode_config_public_key(dc.program_id))

    matches = [
        (key, candidate)
        for key, candidate in users.user_map.items()
        if is_user_protected_maker(candidate.get_user_account())
    ]
    pmm_users = [candidate for _, candidate in matches]
    keys = [key for key, _ in matches]

    return pmm_users, keys


if __name__ == "__main__":
    pmm_users, keys = asyncio.run(main())
    keys.sort()
    print(f"Number of protected maker users: {len(keys)}")
46 |
--------------------------------------------------------------------------------
/examples/protected_order.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 |
4 | from anchorpy.provider import Wallet
5 | from dotenv import load_dotenv
6 | from solana.rpc.async_api import AsyncClient
7 | from solders.keypair import Keypair
8 |
9 | from driftpy.addresses import get_protected_maker_mode_config_public_key
10 | from driftpy.drift_client import DriftClient
11 | from driftpy.math.user_status import is_user_protected_maker
12 |
13 | load_dotenv()
14 |
15 |
async def is_protected_maker():
    """Show the wallet's protected-maker status before and after enabling it.

    Prints the protected-maker-mode config account, the user's open orders,
    the config public key and the current status. It then calls
    ``update_user_protected_maker_orders(0, True)``, prints the result and
    prints the status once more.
    """
    rpc = os.environ.get("RPC_TRITON")
    wallet = Wallet(Keypair.from_base58_string(os.environ.get("PRIVATE_KEY", "")))
    drift_client = DriftClient(
        connection=AsyncClient(rpc),
        wallet=wallet,
        env="mainnet",
    )
    await drift_client.subscribe()

    def print_status(user):
        # Reads the user account again on every call.
        print(
            "Is user protected maker: ",
            is_user_protected_maker(user.get_user_account()),
        )

    config_key = get_protected_maker_mode_config_public_key(drift_client.program_id)
    config_fetcher = drift_client.program.account["ProtectedMakerModeConfig"]
    config_account = await config_fetcher.fetch(config_key)
    print(config_account)

    user = drift_client.get_user()
    print(user.get_open_orders())
    print(get_protected_maker_mode_config_public_key(drift_client.program_id))
    print_status(user)

    print(await drift_client.update_user_protected_maker_orders(0, True))
    print_status(user)

    await drift_client.unsubscribe()


if __name__ == "__main__":
    asyncio.run(is_protected_maker())
49 |
--------------------------------------------------------------------------------
/examples/readme.md:
--------------------------------------------------------------------------------
1 | ## Staking in the Insurance Fund
2 |
How to run:
```bash
# --keypath    path to the keypair file
# --amount     USD amount
# --market     USD spot market
# --operation  operation to perform (here: add)
# --env        environment (here: mainnet)
python if_stake.py \
    --keypath ../keypairs/x19.json \
    --amount 100 \
    --market 0 \
    --operation add \
    --env mainnet
```
--------------------------------------------------------------------------------
/examples/settle_pnl.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | import os
4 |
5 | from anchorpy import Wallet
6 | from dotenv import load_dotenv
7 | from driftpy.drift_client import DriftClient
8 | from solana.rpc.async_api import AsyncClient
9 | from solders.keypair import Keypair
10 | from solders.pubkey import Pubkey
11 |
12 |
13 | load_dotenv()
14 |
15 |
async def setup_drift_client() -> DriftClient:
    """Create a mainnet DriftClient, subscribe it and fetch its accounts once.

    Reads ``RPC_URL`` and ``PRIVATE_KEY`` (base58) from the environment.

    Raises:
        ValueError: if ``PRIVATE_KEY`` is not set.
    """
    private_key = os.getenv("PRIVATE_KEY")
    if not private_key:
        # Without this check ``None`` would be handed to the base58 parser.
        raise ValueError("PRIVATE_KEY must be set in the environment")

    connection = AsyncClient(os.getenv("RPC_URL"))
    kp = Keypair.from_base58_string(private_key)
    wallet = Wallet(kp)
    drift_client = DriftClient(connection=connection, wallet=wallet, env="mainnet")
    await drift_client.subscribe()
    # Explicit fetch after subscribing, kept from the original example.
    await drift_client.account_subscriber.fetch()
    print("drift_client subscribe done")
    return drift_client
25 |
26 |
async def settle_pnl_once(
    drift_client: DriftClient, user_public_key: Pubkey, market_index: int
):
    """Submit one settle-PnL transaction and print its hash.

    NOTE(review): the account data is read from ``drift_client.get_user()``
    while ``user_public_key`` is supplied by the caller; confirm that both
    refer to the same sub-account.
    """
    user_account = drift_client.get_user().get_user_account_and_slot().data
    tx_hash = await drift_client.settle_pnl(
        user_public_key,
        user_account,
        market_index,
    )
    print(f"Settled PNL for market {market_index}, tx: {tx_hash}")
34 |
35 |
async def main():
    """Settle PnL once for the sub-account and market given on the command line.

    Command-line arguments:
        --subaccount_id: sub-account whose public key is settled (required).
        --market_index: market index to settle (required).

    The settlement is abandoned with ``asyncio.TimeoutError`` after 540
    seconds. Errors propagate to the caller unchanged.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--subaccount_id", type=int, required=True)
    parser.add_argument("--market_index", type=int, required=True)
    args = parser.parse_args()

    # The previous bare ``except: raise`` wrapper re-raised everything and has
    # been dropped; exceptions still propagate exactly as before.
    drift_client = await setup_drift_client()
    try:
        user_public_key = drift_client.get_user_account_public_key(args.subaccount_id)
        await asyncio.wait_for(
            settle_pnl_once(drift_client, user_public_key, args.market_index),
            timeout=540,
        )
    finally:
        # Release the subscription whether or not the settlement succeeded.
        await drift_client.unsubscribe()


if __name__ == "__main__":
    asyncio.run(main())
55 |
--------------------------------------------------------------------------------
/examples/spot_market_trade.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | import os
4 |
5 | from anchorpy.provider import Provider, Wallet
6 | from dotenv import load_dotenv
7 | from solana.rpc.async_api import AsyncClient
8 |
9 | from driftpy.constants.spot_markets import mainnet_spot_market_configs
10 | from driftpy.drift_client import DriftClient
11 | from driftpy.keypair import load_keypair
12 | from driftpy.types import TxParams
13 |
14 | logging.basicConfig(level=logging.INFO)
15 | logger = logging.getLogger(__name__)
16 |
17 | load_dotenv()
18 |
19 |
def get_market_by_symbol(symbol: str):
    """Return the first mainnet spot market config whose symbol equals ``symbol``.

    Raises:
        Exception: if no config in ``mainnet_spot_market_configs`` matches.
    """
    found = next(
        (config for config in mainnet_spot_market_configs if config.symbol == symbol),
        None,
    )
    if found is None:
        raise Exception(f"Market {symbol} not found")
    return found
25 |
26 |
async def make_spot_trade():
    """Swap one whole unit of JLP for USDC through a Jupiter route on Drift.

    Reads ``RPC_TRITON`` and ``PRIVATE_KEY`` from the environment, subscribes a
    mainnet DriftClient, builds the swap instructions and sends them.

    Raises:
        Exception: if an environment variable is missing, a symbol is unknown,
            or the input spot market account is not available.
    """
    rpc = os.environ.get("RPC_TRITON")
    secret = os.environ.get("PRIVATE_KEY")
    if not rpc or not secret:
        raise Exception("Missing env vars")
    kp = load_keypair(secret)
    wallet = Wallet(kp)
    logger.info(f"Using wallet: {wallet.public_key}")

    connection = AsyncClient(rpc)
    provider = Provider(connection, wallet)
    drift_client = DriftClient(
        provider.connection,
        provider.wallet,
        env="mainnet",
        tx_params=TxParams(compute_units_price=85_000, compute_units=1_400_000),
    )
    await drift_client.subscribe()

    logger.info("Drift client subscribed")

    market_symbol_1 = "JLP"
    market_symbol_2 = "USDC"

    # Resolve each symbol once and reuse the index below.
    in_market_index = get_market_by_symbol(market_symbol_1).market_index
    out_market_index = get_market_by_symbol(market_symbol_2).market_index

    in_decimals_result = drift_client.get_spot_market_account(in_market_index)
    if not in_decimals_result:
        # Report the market that was actually looked up.
        logger.error(f"{market_symbol_1} market not found")
        raise Exception("Market not found")

    in_decimals = in_decimals_result.decimals
    logger.info(f"{market_symbol_1} decimals: {in_decimals}")

    # One whole token expressed in the market's smallest unit.
    swap_amount = int(1 * 10**in_decimals)
    logger.info(f"Swapping {swap_amount} {market_symbol_1} to {market_symbol_2}")

    swap_ixs, swap_lookups = await drift_client.get_jupiter_swap_ix_v6(
        out_market_idx=out_market_index,
        in_market_idx=in_market_index,
        amount=swap_amount,
        swap_mode="ExactIn",
        only_direct_routes=True,
        max_accounts=20,
    )
    await drift_client.send_ixs(
        ixs=swap_ixs,
        lookup_tables=swap_lookups,
    )
    logger.info("Swap complete")
75 |
76 |
77 | if __name__ == "__main__":
78 | asyncio.run(make_spot_trade())
79 |
--------------------------------------------------------------------------------
/examples/view.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | import time
4 | from pprint import pprint
5 |
6 | import dotenv
7 | from anchorpy import Wallet
8 | from solana.rpc.async_api import AsyncClient
9 | from solders.pubkey import Pubkey
10 |
11 | from driftpy.account_subscription_config import AccountSubscriptionConfig
12 | from driftpy.constants.numeric_constants import (
13 | MARGIN_PRECISION,
14 | QUOTE_PRECISION,
15 | )
16 | from driftpy.decode.utils import decode_name
17 | from driftpy.drift_client import DriftClient
18 | from driftpy.drift_user import DriftUser
19 | from driftpy.keypair import load_keypair
20 | from driftpy.math.conversion import convert_to_number
21 | from driftpy.math.perp_position import is_available
22 |
23 | dotenv.load_dotenv()
24 |
25 |
async def main():
    """Print an overview of the wallet's Drift user, then of a second user.

    Reads ``PRIVATE_KEY`` and ``RPC_TRITON`` from the environment and subscribes
    a websocket-backed DriftClient. If ``RANDOM_USER_ACCOUNT_PUBKEY`` is set,
    that user account is subscribed and summarized as well; otherwise that
    section is skipped. The client is unsubscribed and the connection closed
    before returning, whether or not an error occurred.
    """
    s = time.time()
    kp = load_keypair(os.getenv("PRIVATE_KEY"))
    wallet = Wallet(kp)
    connection = AsyncClient(os.getenv("RPC_TRITON"))
    dc = DriftClient(
        connection,
        wallet,
        account_subscription=AccountSubscriptionConfig("websocket"),
    )
    await dc.subscribe()
    try:
        drift_user = dc.get_user()
        user = drift_user.get_user_account()
        print("\n=== User Info ===")
        print(f"Subaccount name: {decode_name(user.name)}")

        spot_collateral = drift_user.get_spot_market_asset_value(
            None,
            include_open_orders=True,
        )
        print("\n=== Collateral & PnL ===")
        print(f"Spot collateral: ${spot_collateral / QUOTE_PRECISION:,.2f}")

        pnl = drift_user.get_unrealized_pnl(False)
        print(f"Unrealized PnL: ${pnl / QUOTE_PRECISION:,.2f}")

        total_collateral = drift_user.get_total_collateral()
        print(f"Total collateral: ${total_collateral / QUOTE_PRECISION:,.2f}")

        perp_liability = drift_user.get_total_perp_position_liability()
        spot_liability = drift_user.get_spot_market_liability_value()
        print("\n=== Liabilities ===")
        print(f"Perp liability: ${perp_liability / QUOTE_PRECISION:,.2f}")
        print(f"Spot liability: ${spot_liability / QUOTE_PRECISION:,.2f}")

        # Market-level figures below are for perp market index 0.
        perp_market = dc.get_perp_market_account(0)
        oracle = convert_to_number(
            dc.get_oracle_price_data_for_perp_market(0).price, QUOTE_PRECISION
        )
        print("\n=== Market Info ===")
        print(f"Oracle price: ${oracle:,.2f}")

        init_leverage = MARGIN_PRECISION / perp_market.margin_ratio_initial
        maint_leverage = MARGIN_PRECISION / perp_market.margin_ratio_maintenance
        print(f"Initial leverage: {init_leverage:.2f}x")
        print(f"Maintenance leverage: {maint_leverage:.2f}x")

        liq_price = drift_user.get_perp_liq_price(0)
        print(f"Liquidation price: ${liq_price:,.2f}")

        total_liability = drift_user.get_margin_requirement(None)
        total_asset_value = drift_user.get_total_collateral()
        print("\n=== Risk Metrics ===")
        print(f"Total liability: ${total_liability / QUOTE_PRECISION:,.2f}")
        print(f"Total asset value: ${total_asset_value / QUOTE_PRECISION:,.2f}")
        print(f"Current leverage: {(drift_user.get_leverage()) / 10_000:.2f}x")

        user = drift_user.get_user_account()
        print("\n=== Perp Positions ===")
        for position in user.perp_positions:
            if not is_available(position):
                pprint(position)

        print(f"\nTime taken: {time.time() - s:.2f}s")

        orders = drift_user.get_open_orders()
        print("\n=== Orders ===")
        pprint(orders, indent=2, width=80)

        print("\n=== Health Components ===")
        pprint(drift_user.get_health_components(), indent=2, width=80)

        # Another user. Skip the section instead of crashing on a missing
        # setting after everything else has already been printed.
        other_user_pubkey = os.getenv("RANDOM_USER_ACCOUNT_PUBKEY")
        if not other_user_pubkey:
            print("\nRANDOM_USER_ACCOUNT_PUBKEY is not set; skipping second user")
            return

        drift_user2 = DriftUser(
            drift_client=dc,
            user_public_key=Pubkey.from_string(other_user_pubkey),
        )
        await drift_user2.subscribe()

        print("\n=== Perp Positions (User 2) ===")
        for position in drift_user2.get_user_account().perp_positions:
            if position.base_asset_amount != 0:
                pprint(position)

        pnl = drift_user2.get_net_usd_value()
        print(f"Net USD Value: ${pnl / QUOTE_PRECISION:,.2f}")
    finally:
        # Release the websocket subscription and the RPC connection.
        await dc.unsubscribe()
        await connection.close()
112 |
113 |
114 | if __name__ == "__main__":
115 | asyncio.run(main())
116 |
--------------------------------------------------------------------------------
/examples/ws_simple.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | import os
4 |
5 | import dotenv
6 | from anchorpy.provider import Wallet
7 | from solana.rpc.async_api import AsyncClient
8 | from solders.keypair import Keypair
9 |
10 | from driftpy.accounts.ws.drift_client import WebsocketDriftClientAccountSubscriber
11 | from driftpy.drift_client import DriftClient
12 | from driftpy.keypair import load_keypair
13 | from driftpy.types import (
14 | TxParams,
15 | )
16 |
17 | logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
18 | logger = logging.getLogger(__name__)
19 |
20 |
async def get_all_spot_indexes(rpc_url: str):
    """Return the ``market_index`` of every ``SpotMarket`` program account.

    Uses a throwaway mainnet DriftClient with a freshly generated keypair; the
    client is unsubscribed and its connection closed before returning, also when
    the account fetch fails.

    Args:
        rpc_url: RPC endpoint to query.
    """
    throwaway_drift_client = DriftClient(
        connection=AsyncClient(rpc_url),
        wallet=Wallet(Keypair()),
        env="mainnet",
        spot_market_indexes=[],
    )
    try:
        spot_markets = await throwaway_drift_client.program.account["SpotMarket"].all()
    finally:
        await throwaway_drift_client.unsubscribe()
        await throwaway_drift_client.connection.close()
    return [market.account.market_index for market in spot_markets]
32 |
33 |
async def get_all_perp_indexes(rpc_url: str):
    """Return the ``market_index`` of every ``PerpMarket`` program account.

    Uses a throwaway mainnet DriftClient with a freshly generated keypair; the
    client is unsubscribed and its connection closed before returning, also when
    the account fetch fails.

    Args:
        rpc_url: RPC endpoint to query.
    """
    throwaway_drift_client = DriftClient(
        connection=AsyncClient(rpc_url),
        wallet=Wallet(Keypair()),
        env="mainnet",
        perp_market_indexes=[],
    )
    try:
        perp_markets = await throwaway_drift_client.program.account["PerpMarket"].all()
    finally:
        await throwaway_drift_client.unsubscribe()
        await throwaway_drift_client.connection.close()
    return [market.account.market_index for market in perp_markets]
45 |
46 |
async def main():
    """Subscribe a DriftClient and print a perp market's last oracle price every 5s.

    Reads ``RPC_TRITON`` and ``PRIVATE_KEY`` from the environment. The market
    account is re-read from the client on every iteration so the printed value
    follows subscription updates instead of repeating the first snapshot. The
    client is always unsubscribed and its connection closed on exit.

    Raises:
        Exception: if an environment variable is missing, the market account is
            unavailable, or the account subscriber is not websocket-based.
    """
    logger.info("Starting...")
    dotenv.load_dotenv()
    rpc_url = os.getenv("RPC_TRITON")
    private_key = os.getenv("PRIVATE_KEY")
    if not rpc_url or not private_key:
        raise Exception("Missing env vars")
    kp = load_keypair(private_key)

    drift_client = DriftClient(
        connection=AsyncClient(rpc_url),
        wallet=Wallet(kp),
        env="mainnet",
        tx_params=TxParams(700_000, 10_000),
    )
    await drift_client.subscribe()

    market_index = 65
    try:
        # Checks run inside the try so a failure still reaches the cleanup below.
        perp_markets = drift_client.get_perp_market_account(market_index)
        if perp_markets is None:
            raise Exception("No perp markets found")
        print(perp_markets)

        if not isinstance(
            drift_client.account_subscriber, WebsocketDriftClientAccountSubscriber
        ):
            raise Exception("Account subscriber is not a WebsocketAccountSubscriber")
        while True:
            # Fetch the current account each time; keep the previous one if the
            # client momentarily has nothing to return.
            latest = drift_client.get_perp_market_account(market_index)
            if latest is not None:
                perp_markets = latest
            print(perp_markets.amm.historical_oracle_data.last_oracle_price / 1e6)
            await asyncio.sleep(5)
    except KeyboardInterrupt:
        logger.info("Interrupted by user. Exiting loop...")
    finally:
        await drift_client.unsubscribe()
        await drift_client.connection.close()
        logger.info("Unsubscribed from Drift client.")
82 |
83 |
84 | if __name__ == "__main__":
85 | asyncio.run(main())
86 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: DriftPy
2 | theme:
3 | name: material
4 | icon:
5 | logo: material/chart-line
6 | favicon: img/drift.png
7 | palette:
8 | - scheme: default
9 | primary: deep-purple
10 | toggle:
11 | icon: material/toggle-switch-off-outline
12 | name: Switch to dark mode
13 | - scheme: slate
14 | toggle:
15 | icon: material/toggle-switch
16 | name: Switch to light mode
17 | markdown_extensions:
18 | - pymdownx.highlight
19 | - pymdownx.superfences
20 | - admonition
21 | - pymdownx.snippets
22 | - meta
23 | - pymdownx.tabbed:
24 | alternate_style: true
25 | repo_url: https://github.com/drift-labs/driftpy
26 | repo_name: drift-labs/driftpy
27 | site_url: https://drift-labs.github.io/driftpy/
28 | plugins:
29 | - mkdocstrings:
30 | handlers:
31 | python:
32 | selection:
33 | filters:
34 | - "!^_" # exclude all members starting with _
35 | - "^__init__$" # but always include __init__ modules and methods
36 | - "!^T$"
37 | - "!^get_clearing_house_state_account_public_key_and_nonce$"
38 | rendering:
39 | show_root_heading: true
40 | show_root_full_path: false
41 | show_if_no_docstring: true
42 | - search
43 | nav:
44 | - index.md
45 | - clearing_house.md
46 | - clearing_house_user.md
47 | - accounts.md
48 | - addresses.md
49 | extra_css:
50 | - css/mkdocstrings.css
51 | - css/custom.css
52 |
--------------------------------------------------------------------------------
/poetry.toml:
--------------------------------------------------------------------------------
1 | [virtualenvs]
2 | in-project = true
3 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "driftpy"
3 | version = "0.8.56"
4 | description = "A Python client for the Drift DEX"
5 | authors = [
6 | "x19 ",
7 | "bigz ",
8 | "frank ",
9 | "sina ",
10 | ]
11 | license = "MIT"
12 | readme = "README.md"
13 | homepage = "https://github.com/drift-labs/driftpy"
14 | documentation = "https://drift-labs.github.io/driftpy/"
15 |
16 | [tool.poetry.dependencies]
17 | python = "^3.10"
18 | anchorpy = "0.21.0"
19 | solana = "^0.36"
20 | requests = "^2.28.1"
21 | pythclient = "0.2.1"
22 | aiodns = "3.0.0"
23 | aiohttp = "^3.9.1"
24 | aiosignal = "1.3.1"
25 | anchorpy-core = "0.2.0"
26 | anyio = "4.4.0"
27 | apischema = "0.17.5"
28 | async-timeout = "^4.0.2"
29 | attrs = "22.2.0"
30 | backoff = "2.2.1"
31 | base58 = "2.1.1"
32 | based58 = "0.1.1"
33 | borsh-construct = "0.1.0"
34 | cachetools = "5.3"
35 | certifi = "2022.12.7"
36 | cffi = "1.15.1"
37 | charset-normalizer = "2.1.1"
38 | construct = "2.10.68"
39 | construct-typing = "0.5.3"
40 | dnspython = "2.2.1"
41 | exceptiongroup = "1.0.4"
42 | h11 = "0.14.0"
43 | httpcore = "1.0.7"
44 | httpx = "0.28.1"
45 | idna = "3.4"
46 | iniconfig = "1.1.1"
47 | jsonalias = "0.1.1"
48 | jsonrpcclient = "4.0.3"
49 | jsonrpcserver = "5.0.9"
50 | jsonschema = "4.18.0"
51 | loguru = "^0.7.0"
52 | mccabe = "0.7.0"
53 | more-itertools = "8.14.0"
54 | oslash = "0.6.3"
55 | packaging = "23.1"
56 | psutil = "5.9.4"
57 | py = "1.11.0"
58 | pycares = "4.3.0"
59 | pycodestyle = "2.10.0"
60 | pycparser = "2.21"
61 | pyflakes = "3.0.1"
62 | pyheck = "0.1.5"
63 | pyrsistent = "0.19.2"
64 | rfc3986 = "1.5.0"
65 | sniffio = "1.3.0"
66 | solders = ">=0.23.0,<0.27.0"
67 | sumtypes = "0.1a6"
68 | toml = "0.10.2"
69 | tomli = "2.0.1"
70 | toolz = "0.11.2"
71 | types-cachetools = "4.2.10"
72 | typing-extensions = "^4.4.0"
73 | urllib3 = "1.26.13"
74 | websockets = "13.0"
75 | yarl = "1.9.4"
76 | zstandard = "0.18.0"
77 | deprecated = "^1.2.14"
78 | events = "^0.5"
79 | numpy = "^1.26.2"
80 | grpcio = "1.68.1"
81 | protobuf = "5.29.2"
82 | pynacl = "^1.5.0"
83 | tqdm = "^4.67.1"
84 |
85 | [tool.poetry.group.dev.dependencies]
86 | pytest = "^7.4.4"
87 | flake8 = "6.0.0"
88 | black = "24.4.2"
89 | pytest-asyncio = "0.21.0"
90 | mkdocs = "^1.3.0"
91 | mkdocstrings = "^0.17.0"
92 | mkdocs-material = "^8.1.8"
93 | bump2version = "^1.0.1"
94 | autopep8 = "^2.0.4"
95 | mypy = "^1.7.0"
96 | python-dotenv = "^1.0.0"
97 | ruff = "^0.8.4"
98 | # drift-jit-proxy = ">=0.1.6"
99 | pytest-xprocess = "0.18.1"
100 | types-requests = "^2.28.9"
101 | jinja2 = "^3.1.2"
102 |
103 | [build-system]
104 | requires = ["poetry-core>=1.0.0"]
105 | build-backend = "poetry.core.masonry.api"
106 |
107 | [tool.pytest.ini_options]
108 | asyncio_mode = "strict"
109 |
110 | [tool.ruff]
111 | exclude = [
112 | ".git",
113 | "__pycache__",
114 | "docs/source/conf.py",
115 | "old",
116 | "build",
117 | "dist",
118 | "**/geyser_codegen/**",
119 | ]
120 |
121 | [tool.ruff.lint.pycodestyle]
122 | max-line-length = 88
123 |
124 | [tool.pyright]
125 | reportMissingModuleSource = false
126 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiodns==3.0.0
2 | aiohttp==3.8.3
3 | aiosignal==1.3.1
4 | anchorpy==0.20.1
5 | anchorpy-core==0.2.0
6 | anyio==3.6.2
7 | apischema==0.17.5
8 | async-timeout==4.0.2
9 | attrs==22.1.0
10 | backoff==2.2.1
11 | base58==2.1.1
12 | based58==0.1.1
13 | borsh-construct==0.1.0
14 | cachetools==4.2.4
15 | certifi==2022.12.7
16 | cffi==1.15.1
17 | charset-normalizer==2.1.1
18 | construct==2.10.68
19 | construct-typing==0.5.3
20 | dnspython==2.2.1
21 | exceptiongroup==1.0.4
22 | flake8==6.0.0
23 | frozenlist==1.3.3
24 | h11==0.14.0
25 | httpcore==0.16.3
26 | httpx==0.23.1
27 | idna==3.4
28 | iniconfig==1.1.1
29 | jsonalias==0.1.1
30 | jsonrpcclient==4.0.2
31 | jsonrpcserver==5.0.9
32 | jsonschema==4.17.3
33 | loguru==0.6.0
34 | mccabe==0.7.0
35 | more-itertools==8.14.0
36 | multidict==6.0.3
37 | OSlash==0.6.3
38 | packaging==22.0
39 | pluggy==1.0.0
40 | psutil==5.9.4
41 | py==1.11.0
42 | pycares==4.3.0
43 | pycodestyle==2.10.0
44 | pycparser==2.21
45 | pyflakes==3.0.1
46 | pyheck==0.1.5
47 | pyrsistent==0.19.2
48 | pytest==6.2.5
49 | pytest-asyncio==0.17.2
50 | pytest-xprocess==0.18.1
51 | pythclient==0.2.1
52 | requests==2.28.1
53 | rfc3986==1.5.0
54 | sniffio==1.3.0
55 | solana==0.36.0
56 | solders==0.21.0
57 | sumtypes==0.1a6
58 | toml==0.10.2
59 | tomli==2.0.1
60 | toolz==0.11.2
61 | types-cachetools==4.2.10
62 | typing_extensions==4.4.0
63 | urllib3==1.26.13
64 | websockets==10.4
65 | yarl==1.8.2
66 | zstandard==0.17.0
67 |
--------------------------------------------------------------------------------
/scripts/bump.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 |
4 |
def bump_version(version):
    """Return ``version`` ("major.minor.patch") with the patch number increased by one.

    Raises:
        ValueError: if ``version`` does not consist of exactly three
            dot-separated integers.
    """
    numbers = [int(piece) for piece in version.split(".")]
    major, minor, patch = numbers
    return ".".join(str(number) for number in (major, minor, patch + 1))
9 |
10 |
def update_file(file_path, pattern, replacement):
    """Rewrite ``file_path`` with every regex match of ``pattern`` replaced.

    Args:
        file_path: Path of the text file to rewrite in place.
        pattern: Regular expression passed to ``re.sub``.
        replacement: Replacement template passed to ``re.sub``.
    """
    with open(file_path, "r") as source:
        original_text = source.read()

    with open(file_path, "w") as target:
        target.write(re.sub(pattern, replacement, original_text))
19 |
20 |
def main():
    """Bump the patch version in pyproject.toml, driftpy/__init__.py and .bumpversion.cfg.

    The current version is read from pyproject.toml in the project root (the
    parent of this script's directory). If no version is found there, a message
    is printed and no file is modified.
    """
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    pyproject_path = os.path.join(project_root, "pyproject.toml")
    pyproject_pattern = r'(version\s*=\s*["\'])(\d+\.\d+\.\d+)(["\'])'

    with open(pyproject_path, "r") as file:
        found = re.search(pyproject_pattern, file.read())
    if not found:
        print("Couldn't find version in pyproject.toml")
        return

    current_version = found.group(2)
    new_version = bump_version(current_version)

    # (path, pattern, replacement) for each file that records the version,
    # rewritten in this order.
    targets = (
        (pyproject_path, pyproject_pattern, rf"\g<1>{new_version}\g<3>"),
        (
            os.path.join(project_root, "src", "driftpy", "__init__.py"),
            r'(__version__\s*=\s*["\'])(\d+\.\d+\.\d+)(["\'])',
            rf"\g<1>{new_version}\g<3>",
        ),
        (
            os.path.join(project_root, ".bumpversion.cfg"),
            r"(current_version\s*=\s*)(\d+\.\d+\.\d+)",
            rf"\g<1>{new_version}",
        ),
    )
    for path, pattern, replacement in targets:
        update_file(path, pattern, replacement)

    print(f"Version bumped from {current_version} to {new_version}")
50 |
51 |
52 | if __name__ == "__main__":
53 | main()
54 |
--------------------------------------------------------------------------------
/scripts/ci.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Run the CI test suite, then the math test suite.
# The math suite is skipped when the CI suite fails.
set -e

echo "Running tests:"

# Commands inside an && / || list do not trigger `set -e`, so a failing suite
# reaches the report below instead of aborting the script before any message.
exit_code=0
pytest -v -s -x tests/ci/*.py && pytest -v -s tests/math/*.py || exit_code=$?

if [ $exit_code -ne 0 ]; then
  echo "Tests failed with exit code $exit_code"
  exit $exit_code
fi

echo "All tests passed successfully"
--------------------------------------------------------------------------------
/scripts/decode.sh:
--------------------------------------------------------------------------------
1 | pytest -v -s tests/decode/*.py
--------------------------------------------------------------------------------
/scripts/dlob.sh:
--------------------------------------------------------------------------------
1 | pytest -v -s -x tests/dlob/dlob.py
--------------------------------------------------------------------------------
/scripts/generate_constants.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | import os
4 |
5 | import dotenv
6 | from anchorpy import Wallet
7 | from solana.rpc.async_api import AsyncClient
8 |
9 | from driftpy.account_subscription_config import AccountSubscriptionConfig
10 | from driftpy.constants.perp_markets import PerpMarketConfig
11 | from driftpy.constants.spot_markets import SpotMarketConfig
12 | from driftpy.drift_client import DriftClient
13 |
14 | dotenv.load_dotenv()
15 |
16 |
def decode_name(name) -> str:
    """Decode ``name`` (anything ``bytes()`` accepts) as UTF-8 and strip surrounding whitespace."""
    raw = bytes(name)
    return str(raw, "utf-8").strip()
19 |
20 |
async def generate_spot_configs(drift_client: DriftClient) -> str:
    """Render Python source for ``mainnet_spot_market_configs``.

    Reads every spot market account from ``drift_client``, orders them by
    ``market_index`` and emits one ``SpotMarketConfig(...)`` entry per market.

    Args:
        drift_client: Subscribed client whose spot market accounts are read.

    Returns:
        Source text defining the ``mainnet_spot_market_configs`` list.
    """
    spot_markets = sorted(
        drift_client.get_spot_market_accounts(), key=lambda market: market.market_index
    )

    configs = [
        SpotMarketConfig(
            symbol=decode_name(market.name),
            market_index=market.market_index,
            oracle=market.oracle,
            oracle_source=market.oracle_source,
            mint=market.mint,
        )
        for market in spot_markets
    ]

    header = """
mainnet_spot_market_configs: list[SpotMarketConfig] = ["""

    # Build all entries first and join once instead of growing a string in a loop.
    entries = [
        f"""
    SpotMarketConfig(
        symbol="{config.symbol}",
        market_index={config.market_index},
        oracle=Pubkey.from_string("{str(config.oracle)}"),
        oracle_source=OracleSource.{config.oracle_source.__class__.__name__}(),  # type: ignore
        mint=Pubkey.from_string("{str(config.mint)}"),
    ),"""
        for config in configs
    ]

    return header + "".join(entries) + "\n]\n"
52 |
53 |
async def generate_perp_configs(drift_client: DriftClient) -> str:
    """Render Python source for ``mainnet_perp_market_configs``.

    Reads every perp market account from ``drift_client``, orders them by
    ``market_index`` and emits one ``PerpMarketConfig(...)`` entry per market.
    The base asset symbol is the market name with its last "-"-separated
    segment removed.

    Args:
        drift_client: Subscribed client whose perp market accounts are read.

    Returns:
        Source text defining the ``mainnet_perp_market_configs`` list.
    """
    perp_markets = sorted(
        drift_client.get_perp_market_accounts(), key=lambda market: market.market_index
    )

    configs = []
    for market in perp_markets:
        # Decode once; both the symbol and the base asset symbol derive from it.
        symbol = decode_name(market.name)
        configs.append(
            PerpMarketConfig(
                symbol=symbol,
                base_asset_symbol="-".join(symbol.split("-")[:-1]),
                market_index=market.market_index,
                oracle=market.amm.oracle,
                oracle_source=market.amm.oracle_source,
            )
        )

    # Generate Python code
    header = """
mainnet_perp_market_configs: list[PerpMarketConfig] = ["""

    # Build all entries first and join once instead of growing a string in a loop.
    entries = [
        f"""
    PerpMarketConfig(
        symbol="{config.symbol}",
        base_asset_symbol="{config.base_asset_symbol}",
        market_index={config.market_index},
        oracle=Pubkey.from_string("{str(config.oracle)}"),
        oracle_source=OracleSource.{config.oracle_source.__class__.__name__}(),  # type: ignore
    ),"""
        for config in configs
    ]

    return header + "".join(entries) + "\n]\n"
87 |
88 |
async def main():
    """Print generated market-config source for the requested market type.

    Command-line arguments:
        --market-type: "perp" or "spot".

    Reads the RPC endpoint from ``MAINNET_RPC_ENDPOINT`` and uses a dummy wallet
    with the "cached" account subscription.

    Raises:
        ValueError: if ``MAINNET_RPC_ENDPOINT`` is not set.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--market-type", choices=["perp", "spot"], required=True)
    args = parser.parse_args()

    rpc_url = os.getenv("MAINNET_RPC_ENDPOINT")
    if not rpc_url:
        raise ValueError("MAINNET_RPC_ENDPOINT is not set")

    drift_client = DriftClient(
        AsyncClient(rpc_url),
        Wallet.dummy(),
        env="mainnet",
        account_subscription=AccountSubscriptionConfig("cached"),
    )

    await drift_client.subscribe()
    try:
        if args.market_type == "perp":
            output = await generate_perp_configs(drift_client)
        else:
            output = await generate_spot_configs(drift_client)

        print(output)
    finally:
        # Unsubscribe even when generation fails.
        await drift_client.unsubscribe()
114 |
115 |
116 | if __name__ == "__main__":
117 | asyncio.run(main())
118 |
--------------------------------------------------------------------------------
/scripts/math.sh:
--------------------------------------------------------------------------------
1 | pytest -v -s tests/math/*.py
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
# Fetch the protocol submodule, then install its dependencies and build it.
git submodule update --init --recursive
# Chain on `cd` so the install and build never run in the wrong directory.
cd protocol-v2 && yarn && anchor build
--------------------------------------------------------------------------------
/src/driftpy/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.8.56"
2 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/__init__.py:
--------------------------------------------------------------------------------
1 | from .get_accounts import *
2 | from .types import *
3 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/cache/__init__.py:
--------------------------------------------------------------------------------
1 | from .drift_client import *
2 | from .user import *
3 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/cache/user.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from anchorpy import Program
4 | from solders.pubkey import Pubkey
5 | from solana.rpc.commitment import Commitment
6 |
7 | from driftpy.accounts import get_user_account_and_slot
8 | from driftpy.accounts import UserAccountSubscriber, DataAndSlot
9 | from driftpy.types import UserAccount
10 |
11 |
class CachedUserAccountSubscriber(UserAccountSubscriber):
    """User account subscriber that holds one cached snapshot.

    The snapshot is loaded from the program by ``subscribe``/``fetch`` and can
    also be replaced through ``update_data``.
    """

    def __init__(
        self,
        user_pubkey: Pubkey,
        program: Program,
        commitment: Commitment = "confirmed",
    ):
        self.user_pubkey = user_pubkey
        self.program = program
        # Stored on the instance; not read by any method of this class.
        self.commitment = commitment
        # Latest known account data with its slot; None until loaded.
        self.user_and_slot = None

    async def subscribe(self):
        """Load the initial snapshot."""
        await self.update_cache()

    async def update_cache(self):
        """Fetch the user account and slot and store them as the snapshot."""
        self.user_and_slot = await get_user_account_and_slot(
            self.program, self.user_pubkey
        )

    async def fetch(self):
        """Reload the snapshot."""
        await self.update_cache()

    def update_data(self, data: Optional[DataAndSlot[UserAccount]]):
        """Store ``data`` unless it is None or from a slot older than the snapshot."""
        if data is None:
            return
        current = self.user_and_slot
        if current is not None and data.slot < current.slot:
            return
        self.user_and_slot = data

    def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
        """Return the snapshot, or None when nothing has been loaded."""
        return self.user_and_slot

    def unsubscribe(self):
        """Drop the snapshot."""
        self.user_and_slot = None
44 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/demo/__init__.py:
--------------------------------------------------------------------------------
1 | from .drift_client import *
2 | from .user import *
3 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/demo/user.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from anchorpy import Program
4 | from solders.pubkey import Pubkey
5 | from solana.rpc.commitment import Commitment
6 |
7 | from driftpy.accounts import get_user_account_and_slot
8 | from driftpy.accounts import UserAccountSubscriber, DataAndSlot
9 | from driftpy.types import UserAccount
10 |
11 |
class DemoUserAccountSubscriber(UserAccountSubscriber):
    """User account subscriber that holds one snapshot of the account.

    The snapshot is loaded from the program by ``subscribe``/``fetch`` and can
    also be replaced through ``update_data``.
    """

    def __init__(
        self,
        user_pubkey: Pubkey,
        program: Program,
        commitment: Commitment = "confirmed",
    ):
        self.user_pubkey = user_pubkey
        self.program = program
        # Stored on the instance; not read by any method of this class.
        self.commitment = commitment
        # Latest known account data with its slot; None until loaded.
        self.user_and_slot = None

    async def subscribe(self):
        """Load the initial snapshot."""
        await self.update_cache()

    async def update_cache(self):
        """Fetch the user account and slot and store them as the snapshot."""
        self.user_and_slot = await get_user_account_and_slot(
            self.program, self.user_pubkey
        )

    async def fetch(self):
        """Reload the snapshot."""
        await self.update_cache()

    def update_data(self, data: Optional[DataAndSlot[UserAccount]]):
        """Store ``data`` unless it is None or from a slot older than the snapshot."""
        if data is None:
            return
        current = self.user_and_slot
        if current is not None and data.slot < current.slot:
            return
        self.user_and_slot = data

    def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
        """Return the snapshot, or None when nothing has been loaded."""
        return self.user_and_slot

    def unsubscribe(self):
        """Drop the snapshot."""
        self.user_and_slot = None
44 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/grpc/geyser_codegen/README.md:
--------------------------------------------------------------------------------
1 | # Note for driftpy:
2 |
3 | These generated `geyser_pb2` files were generated following
4 | the instructions on https://github.com/jito-labs/jito-python/ and committed here.
5 |
6 | Original readme below:
7 |
8 |
9 | # Python example
10 |
11 | ## DISCLAIMER
12 |
13 | This example may contain errors or lag behind the latest stable version; please use it only as an example of what your subscription can look like. If you want a well-tested, production-ready example, please check our Rust implementation.
14 |
15 |
16 |
17 | ## Instruction
18 |
19 | This Python example uses [grpc.io](https://grpc.io/) library.
20 | It assumes the CA trust store on your machine trusts the CA used by your RPC endpoint.
21 |
22 | ## Installation
23 |
24 | Create a virtual environment and install its dependencies:
25 | ```bash
26 | $ python -m venv venv
27 | $ . venv/bin/activate
28 | (venv) $ python -m pip install -U pip
29 | (venv) $ python -m pip install -r requirements.txt
30 | ```
31 |
32 | ## Launch the helloworld_geyser
33 |
34 | Print the usage:
35 |
36 | ```bash
37 | (venv) $ python helloworld_geyser.py --help
38 | Usage: helloworld_geyser.py [OPTIONS]
39 |
40 | Simple program to get the latest solana slot number
41 |
42 | Options:
43 | --rpc-fqdn TEXT Fully Qualified domain name of your RPC endpoint
44 | --x-token TEXT x-token to authenticate each gRPC call
45 | --help Show this message and exit.
46 | ```
47 |
48 | - `rpc-fqdn`: is the fully qualified domain name without the `https://`, such as `index.rpcpool.com`.
49 | - `x-token` : is the x-token to authenticate yourself to the RPC node.
50 |
51 | Here is a full example:
52 |
53 | ```bash
54 | (venv) $ python helloworld_geyser.py --rpc-fqdn 'index.rpcpool.com' --x-token '2625ae71-0823-41b3-b3bc-4ff89d762d52'
55 | slot: 264236514
56 |
57 | ```
58 |
59 | **NOTE**: `2625ae71-0823-41b3-b3bc-4ff89d762d52` is a fake x-token, you need to provide your own token.
60 |
61 | ## Generate gRPC service and request signatures
62 |
63 | The library `grpcio` generates the stubs for you.
64 |
65 | From the directory of `helloworld_geyser.py` you can generate all the stubs and data types using the following command:
66 |
67 | ```bash
68 | (venv) $ python -m grpc_tools.protoc -I../../yellowstone-grpc-proto/proto/ --python_out=. --pyi_out=. --grpc_python_out=. ../../yellowstone-grpc-proto/proto/*
69 | ```
70 |
71 | This will generate:
72 | - geyser_pb2.py
73 | - geyser_pb2.pyi
74 | - geyser_pb2_grpc.py
75 | - solana_storage_pb2.py
76 | - solana_storage_pb2.pyi
77 | - solana_storage_pb2_grpc.py
78 |
79 | Which you can then import into your code.
80 |
81 |
82 | ## Useful documentation for grpcio authentication process
83 |
84 | - [secure_channel](https://grpc.github.io/grpc/python/grpc.html#create-client-credentials)
85 |
86 | - [extend auth method via call credentials](https://grpc.io/docs/guides/auth/#extending-grpc-to-support-other-authentication-mechanisms)
87 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/grpc/geyser_codegen/solana_storage_pb2_grpc.py:
--------------------------------------------------------------------------------
1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2 | """Client and server classes corresponding to protobuf-defined services."""
3 | import grpc
4 | import warnings
5 |
6 |
7 | GRPC_GENERATED_VERSION = '1.68.1'
8 | GRPC_VERSION = grpc.__version__
9 | _version_not_supported = False
10 |
11 | try:
12 | from grpc._utilities import first_version_is_lower
13 | _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
14 | except ImportError:
15 | _version_not_supported = True
16 |
17 | if _version_not_supported:
18 | raise RuntimeError(
19 | f'The grpc package installed is at version {GRPC_VERSION},'
20 | + f' but the generated code in solana_storage_pb2_grpc.py depends on'
21 | + f' grpcio>={GRPC_GENERATED_VERSION}.'
22 | + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
23 | + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
24 | )
25 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/grpc/geyser_codegen/subscribe_geyser.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import time
4 | from typing import Iterator, Optional
5 |
6 | import geyser_pb2
7 | import geyser_pb2_grpc
8 | import grpc
9 |
10 |
def _triton_sign_request(
    callback: grpc.AuthMetadataPluginCallback,
    x_token: Optional[str],
    error: Optional[Exception],
):
    """Invoke ``callback`` with a single ``x-token`` metadata entry and ``error``.

    Returns whatever ``callback`` returns.
    """
    return callback((("x-token", x_token),), error)
18 |
19 |
class TritonAuthMetadataPlugin(grpc.AuthMetadataPlugin):
    """Auth metadata plugin that attaches a fixed ``x-token`` entry to calls."""

    def __init__(self, x_token: str):
        # Token forwarded on every invocation of the plugin.
        self.x_token = x_token

    def __call__(
        self,
        context: grpc.AuthMetadataContext,
        callback: grpc.AuthMetadataPluginCallback,
    ):
        """Hand the stored token to ``callback``; ``context`` is not used."""
        return _triton_sign_request(
            callback=callback,
            x_token=self.x_token,
            error=None,
        )
30 |
31 |
def create_subscribe_request() -> Iterator[geyser_pb2.SubscribeRequest]:
    """Yield the initial subscribe request, then a ping request every 30 seconds.

    The first request carries an account filter registered under the key
    ``"account_monitor"`` at CONFIRMED commitment. The generator never
    terminates on its own: after the first request it sleeps and yields pings
    whose id is the current Unix time in whole seconds.
    """
    # Accounts to monitor, as base58 encoded public keys.
    monitored_accounts = (
        "dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH",
        "DRiP8pChV3hr8FkdgPpxpVwQh3dHnTzHgdbn5Z3fmwHc",
        "Fe4hMZrg7R97ZrbSScWBXUpQwZB9gzBnhodTCGyjkHsG",
    )

    accounts_filter = geyser_pb2.SubscribeRequestFilterAccounts()
    for encoded_pubkey in monitored_accounts:
        accounts_filter.account.append(encoded_pubkey)
    accounts_filter.nonempty_txn_signature = True

    # Copy the filter into the request that opens the subscription.
    initial_request = geyser_pb2.SubscribeRequest()
    initial_request.accounts["account_monitor"].CopyFrom(accounts_filter)
    initial_request.commitment = geyser_pb2.CommitmentLevel.CONFIRMED
    yield initial_request

    # Keep connection alive with pings
    while True:
        time.sleep(30)
        keepalive = geyser_pb2.SubscribeRequest()
        keepalive.ping.id = int(time.time())
        yield keepalive
57 |
58 |
def handle_subscribe_updates(stub: geyser_pb2_grpc.GeyserStub):
    """Open the subscription stream on ``stub`` and print each update.

    Account updates are printed field by field; pong updates are printed with
    their id. Any other update is ignored.

    Raises:
        grpc.RpcError: logged via ``logging.error`` and then re-raised.
    """
    try:
        outgoing_requests = create_subscribe_request()
        print("Starting subscription stream...")

        for update in stub.Subscribe(outgoing_requests):
            if update.HasField("account"):
                account_update = update.account
                info = account_update.account
                print("\nAccount Update:")
                print(f" Pubkey: {info.pubkey.hex()}")
                print(f" Owner: {info.owner.hex()}")
                print(f" Lamports: {info.lamports}")
                print(f" Slot: {account_update.slot}")
                # The signature line is only printed when one is present.
                if info.txn_signature:
                    print(f" Transaction: {info.txn_signature.hex()}")
            elif update.HasField("pong"):
                print(f"Received pong: {update.pong.id}")
    except grpc.RpcError as e:
        logging.error(f"RPC error occurred: {str(e)}")
        raise
88 |
89 |
def run_subscription_client():
    """Connect to the gRPC endpoint named in the environment and stream updates.

    Environment variables:
        RPC_FDQN: Endpoint passed to ``grpc.secure_channel``. The correctly
            spelled ``RPC_FQDN`` is accepted as a fallback.
        X_TOKEN: Token sent as ``x-token`` call metadata.

    Raises:
        ValueError: if the endpoint or the token is not set.
        grpc.RpcError: propagated from ``handle_subscribe_updates``.
    """
    # "RPC_FDQN" is the name this script has always read, so it takes
    # precedence; "RPC_FQDN" is accepted as the intended spelling.
    rpc_fqdn = os.environ.get("RPC_FDQN") or os.environ.get("RPC_FQDN")
    x_token = os.environ.get("X_TOKEN")

    # Fail early with a clear message rather than passing None into grpc.
    missing = [
        name
        for name, value in (("RPC_FDQN", rpc_fqdn), ("X_TOKEN", x_token))
        if not value
    ]
    if missing:
        raise ValueError(
            "Missing required environment variable(s): " + ", ".join(missing)
        )

    auth = TritonAuthMetadataPlugin(x_token)
    ssl_creds = grpc.ssl_channel_credentials()
    call_creds = grpc.metadata_call_credentials(auth)
    combined_creds = grpc.composite_channel_credentials(ssl_creds, call_creds)

    with grpc.secure_channel(rpc_fqdn, credentials=combined_creds) as channel:
        stub = geyser_pb2_grpc.GeyserStub(channel)
        handle_subscribe_updates(stub)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    run_subscription_client()
107 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/grpc/user.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from driftpy.accounts.grpc.account_subscriber import GrpcAccountSubscriber
4 | from driftpy.accounts.types import DataAndSlot, UserAccountSubscriber
5 | from driftpy.types import UserAccount
6 |
7 |
class GrpcUserAccountSubscriber(
    GrpcAccountSubscriber[UserAccount], UserAccountSubscriber
):
    """User-account subscriber built on ``GrpcAccountSubscriber``."""

    def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
        """Return the value currently held in ``self.data_and_slot``."""
        return self.data_and_slot
13 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/grpc/user_stats.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from driftpy.accounts.grpc.account_subscriber import GrpcAccountSubscriber
4 | from driftpy.accounts.types import DataAndSlot, UserStatsAccountSubscriber
5 | from driftpy.types import UserStatsAccount
6 |
7 |
class GrpcUserStatsAccountSubscriber(
    GrpcAccountSubscriber[UserStatsAccount], UserStatsAccountSubscriber
):
    """User-stats subscriber built on ``GrpcAccountSubscriber``."""

    def get_user_stats_account_and_slot(
        self,
    ) -> Optional[DataAndSlot[UserStatsAccount]]:
        """Return the value currently held in ``self.data_and_slot``."""
        return self.data_and_slot
15 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/polling/__init__.py:
--------------------------------------------------------------------------------
1 | from .drift_client import *
2 | from .user import *
3 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/polling/user.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from anchorpy import Program
4 | from solders.pubkey import Pubkey
5 |
6 | from driftpy.accounts import (
7 | UserAccountSubscriber,
8 | DataAndSlot,
9 | get_user_account_and_slot,
10 | )
11 |
12 | from driftpy.accounts.bulk_account_loader import BulkAccountLoader
13 | from driftpy.types import UserAccount
14 |
15 |
class PollingUserAccountSubscriber(UserAccountSubscriber):
    """User-account subscriber fed by a shared ``BulkAccountLoader``.

    The subscriber registers a callback with the loader and keeps the most
    recent decoded account together with the slot it was observed at.
    """

    def __init__(
        self,
        user_account_pubkey: Pubkey,
        program: Program,
        bulk_account_loader: BulkAccountLoader,
    ):
        self.user_account_pubkey = user_account_pubkey
        self.program = program
        self.bulk_account_loader = bulk_account_loader
        # Decoder taken from the program's account coder.
        self.decode = self.program.coder.accounts.decode
        # None until the first fetch or loader callback delivers data.
        self.data_and_slot: Optional[DataAndSlot[UserAccount]] = None
        # Id returned by the loader on registration; None while unregistered.
        self.callback_id = None

    async def subscribe(self):
        """Register with the loader and fetch once if no data is held yet.

        Does nothing when a loader callback is already registered.
        """
        if self.callback_id is None:
            self.add_to_account_loader()

            if self.data_and_slot is None:
                await self.fetch()

    def add_to_account_loader(self):
        """Register the callback with the loader unless already registered."""
        if self.callback_id is None:
            self.callback_id = self.bulk_account_loader.add_account(
                self.user_account_pubkey, self._account_loader_callback
            )

    def _account_loader_callback(self, buffer: bytes, slot: int):
        """Decode ``buffer`` and store it if ``slot`` is newer than the held one.

        A ``None`` buffer, or a slot not greater than the stored slot, is ignored.
        """
        if buffer is None:
            return

        held = self.data_and_slot
        if held is not None and held.slot >= slot:
            return

        self.data_and_slot = DataAndSlot(slot, self.decode(buffer))

    async def fetch(self):
        """Fetch the account via ``get_user_account_and_slot`` and store the result."""
        fetched = await get_user_account_and_slot(
            self.program, self.user_account_pubkey
        )
        self.update_data(fetched)

    def update_data(self, new_data: Optional[DataAndSlot[UserAccount]]):
        """Store ``new_data`` unless it is ``None`` or older than the held value.

        Data at the same slot as the held value replaces it.
        """
        if new_data is None:
            return

        held = self.data_and_slot
        if held is not None and new_data.slot < held.slot:
            return

        self.data_and_slot = new_data

    def unsubscribe(self):
        """Remove the loader callback, if one is registered."""
        if self.callback_id is not None:
            self.bulk_account_loader.remove_account(
                self.user_account_pubkey, self.callback_id
            )
            self.callback_id = None

    def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
        """Return the held account and slot, or ``None`` if nothing is stored."""
        return self.data_and_slot
82 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/types.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from dataclasses import dataclass
3 | from typing import Awaitable, Callable, Generic, Optional, Sequence, TypeVar, Union
4 |
5 | from solana.rpc.commitment import Commitment
6 | from solana.rpc.types import MemcmpOpts
7 | from solders.pubkey import Pubkey
8 |
9 | from driftpy.types import (
10 | OraclePriceData,
11 | OracleSource,
12 | PerpMarketAccount,
13 | SpotMarketAccount,
14 | StateAccount,
15 | UserAccount,
16 | UserStatsAccount,
17 | )
18 |
# Type of the payload carried by DataAndSlot.
T = TypeVar("T")


@dataclass
class DataAndSlot(Generic[T]):
    """A payload of type ``T`` paired with a slot number."""

    # Slot number associated with ``data``.
    slot: int
    # The payload itself.
    data: T
26 |
27 |
@dataclass
class FullOracleWrapper:
    """Bundles an oracle's public key, source type and optional price data."""

    # Public key of the oracle.
    pubkey: Pubkey
    # Which kind of oracle this is.
    oracle_source: OracleSource
    # Price data with its slot; may be None.
    oracle_price_data_and_slot: Optional[DataAndSlot[OraclePriceData]]
33 |
34 |
@dataclass
class BufferAndSlot:
    """Raw bytes paired with a slot number."""

    # Slot number associated with ``buffer``.
    slot: int
    # Undecoded bytes.
    buffer: bytes
39 |
40 |
@dataclass
class WebsocketProgramAccountOptions:
    """Options for a websocket program-account subscription."""

    # Memcmp filters for the subscription.
    filters: Sequence[MemcmpOpts]
    # Commitment level for the subscription.
    commitment: Commitment
    # Encoding name, e.g. "base64".
    encoding: str
46 |
47 |
@dataclass
class GrpcProgramAccountOptions:
    """Options for a gRPC program-account subscription."""

    # Memcmp filters for the subscription.
    filters: Sequence[MemcmpOpts]
    # Commitment level for the subscription.
    commitment: Commitment
52 |
53 |
# Async callback taking a string and a user account with its slot.
# NOTE(review): the string is presumably the account's public key — confirm
# against the callers.
UpdateCallback = Callable[[str, DataAndSlot[UserAccount]], Awaitable[None]]

# Same shape as UpdateCallback, but the payload is a perp or spot market account.
MarketUpdateCallback = Callable[
    [str, DataAndSlot[Union[PerpMarketAccount, SpotMarketAccount]]], Awaitable[None]
]
59 |
60 |
class DriftClientAccountSubscriber:
    """Interface for subscribers that expose state, market and oracle accounts.

    The methods are decorated with ``@abstractmethod``, but the class does not
    derive from ``abc.ABC``, so the decorators are not enforced when a
    subclass is instantiated; an un-overridden method simply returns ``None``.
    """

    @abstractmethod
    async def subscribe(self):
        """Start the subscription."""
        pass

    @abstractmethod
    def unsubscribe(self):
        """Stop the subscription."""
        pass

    @abstractmethod
    async def fetch(self):
        """Fetch the subscribed data on demand."""
        pass

    @abstractmethod
    def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]:
        """Return the state account with its slot, or ``None``."""
        pass

    @abstractmethod
    def get_perp_market_and_slot(
        self, market_index: int
    ) -> Optional[DataAndSlot[PerpMarketAccount]]:
        """Return the perp market for ``market_index`` with its slot, or ``None``."""
        pass

    @abstractmethod
    def get_spot_market_and_slot(
        self, market_index: int
    ) -> Optional[DataAndSlot[SpotMarketAccount]]:
        """Return the spot market for ``market_index`` with its slot, or ``None``."""
        pass

    @abstractmethod
    def get_oracle_price_data_and_slot(
        self, oracle: Pubkey
    ) -> Optional[DataAndSlot[OraclePriceData]]:
        """Return price data for ``oracle`` with its slot, or ``None``."""
        pass

    @abstractmethod
    def get_market_accounts_and_slots(self) -> list[DataAndSlot[PerpMarketAccount]]:
        """Return the perp market accounts with their slots."""
        pass

    @abstractmethod
    def get_spot_market_accounts_and_slots(
        self,
    ) -> list[DataAndSlot[SpotMarketAccount]]:
        """Return the spot market accounts with their slots."""
        pass
105 |
106 |
class UserAccountSubscriber:
    """Interface for subscribers that track a single user account.

    The class does not derive from ``abc.ABC``, so ``@abstractmethod`` is not
    enforced at instantiation time.
    """

    @abstractmethod
    async def subscribe(self):
        """Start the subscription."""
        pass

    @abstractmethod
    def unsubscribe(self):
        """Stop the subscription."""
        pass

    @abstractmethod
    async def fetch(self):
        """Fetch the account on demand."""
        pass

    # Not abstract: the default implementation is a no-op coroutine.
    async def update_data(self, data: Optional[DataAndSlot[UserAccount]]):
        """Accept new account data; does nothing unless overridden."""
        pass

    @abstractmethod
    def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
        """Return the user account with its slot, or ``None``."""
        pass
126 |
127 |
class UserStatsAccountSubscriber:
    """Interface for subscribers that track a single user-stats account.

    The class does not derive from ``abc.ABC``, so ``@abstractmethod`` is not
    enforced at instantiation time.
    """

    @abstractmethod
    async def subscribe(self):
        """Start the subscription."""
        pass

    @abstractmethod
    def unsubscribe(self):
        """Stop the subscription."""
        pass

    @abstractmethod
    async def fetch(self):
        """Fetch the account on demand."""
        pass

    @abstractmethod
    def get_user_stats_account_and_slot(
        self,
    ) -> Optional[DataAndSlot[UserStatsAccount]]:
        """Return the user-stats account with its slot, or ``None``."""
        pass
146 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/ws/__init__.py:
--------------------------------------------------------------------------------
1 | from .drift_client import *
2 | from .user import *
3 | from .user_stats import *
4 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/ws/account_subscriber.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Callable, Generic, Optional, TypeVar, cast
3 |
4 | import websockets
5 | import websockets.exceptions # force eager imports
6 | from anchorpy.program.core import Program
7 | from solana.rpc.commitment import Commitment
8 | from solana.rpc.websocket_api import SolanaWsClientProtocol, connect
9 | from solders.pubkey import Pubkey
10 |
11 | from driftpy.accounts import (
12 | DataAndSlot,
13 | UserAccountSubscriber,
14 | UserStatsAccountSubscriber,
15 | get_account_data_and_slot,
16 | )
17 | from driftpy.types import get_ws_url
18 |
T = TypeVar("T")


class WebsocketAccountSubscriber(
    UserAccountSubscriber, UserStatsAccountSubscriber, Generic[T]
):
    """Tracks one account over a websocket subscription.

    The latest decoded value and the slot it was seen at are kept in
    ``data_and_slot``. The websocket loop runs as a background asyncio task.
    """

    def __init__(
        self,
        pubkey: Pubkey,
        program: Program,
        commitment: Commitment = Commitment("confirmed"),
        decode: Optional[Callable[[bytes], T]] = None,
        initial_data: Optional[DataAndSlot] = None,
    ):
        """
        Args:
            pubkey: Account to subscribe to.
            program: Program whose provider supplies the RPC endpoint and
                whose coder is the default decoder.
            commitment: Commitment used for the subscription and fetches.
            decode: Optional decoder for raw account bytes; defaults to
                ``program.coder.accounts.decode``.
            initial_data: Optional starting value; when given, ``subscribe``
                skips the initial fetch.
        """
        self.program = program
        self.commitment = commitment
        self.pubkey = pubkey
        self.data_and_slot = initial_data or None
        self.task = None
        self.decode = (
            decode if decode is not None else self.program.coder.accounts.decode
        )
        self.ws: Optional[SolanaWsClientProtocol] = None

    async def subscribe(self):
        """Start the background websocket task and return it.

        Calling this again while the task is still running returns the
        existing task instead of starting a second one, which would otherwise
        leave the first task running with no handle left to cancel it.
        """
        if self.task is not None and not self.task.done():
            return self.task

        if self.data_and_slot is None:
            await self.fetch()

        self.task = asyncio.create_task(self.subscribe_ws())
        return self.task

    def is_subscribed(self):
        """Return True once a background task has been created."""
        return self.task is not None

    async def subscribe_ws(self):
        """Run the websocket loop, reconnecting when the connection closes."""
        endpoint = self.program.provider.connection._provider.endpoint_uri
        ws_endpoint = get_ws_url(endpoint)

        # `connect` is iterated so that each pass yields a fresh connection.
        async for ws in connect(ws_endpoint):
            try:
                self.ws = cast(SolanaWsClientProtocol, ws)
                await self.ws.account_subscribe(
                    self.pubkey,
                    commitment=self.commitment,
                    encoding="base64",
                )
                # The first response carries the subscription id.
                first_resp = await ws.recv()
                subscription_id = cast(int, first_resp[0].result)

                async for msg in ws:
                    try:
                        slot = int(msg[0].result.context.slot)  # type: ignore

                        if msg[0].result.value is None:
                            continue

                        account_bytes = cast(bytes, msg[0].result.value.data)  # type: ignore
                        decoded_data = self.decode(account_bytes)
                        self.update_data(DataAndSlot(slot, decoded_data))
                    except Exception as e:
                        # Report what went wrong before dropping this
                        # subscription; the outer loop then reconnects.
                        print(f"Error processing account data: {e}")
                        break
                await self.ws.account_unsubscribe(subscription_id)
            except websockets.exceptions.ConnectionClosed:
                print("Websocket closed, reconnecting...")
                continue

    async def fetch(self):
        """Fetch the account over RPC and store it if it is not older."""
        new_data = await get_account_data_and_slot(
            self.pubkey, self.program, self.commitment, self.decode
        )
        self.update_data(new_data)

    def update_data(self, new_data: Optional[DataAndSlot[T]]):
        """Store ``new_data`` unless it is ``None`` or older than the held value."""
        if new_data is None:
            return

        if self.data_and_slot is None or new_data.slot >= self.data_and_slot.slot:
            self.data_and_slot = new_data

    async def unsubscribe(self):
        """Cancel the background task and close the websocket, if present."""
        if self.task:
            self.task.cancel()
            self.task = None
        if self.ws:
            await self.ws.close()
            self.ws = None
106 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/ws/user.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from driftpy.accounts import DataAndSlot
4 | from driftpy.types import UserAccount
5 |
6 | from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber
7 | from driftpy.accounts.types import UserAccountSubscriber
8 |
9 |
class WebsocketUserAccountSubscriber(
    WebsocketAccountSubscriber[UserAccount], UserAccountSubscriber
):
    """User-account subscriber built on ``WebsocketAccountSubscriber``."""

    def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
        """Return the value currently held in ``self.data_and_slot``."""
        return self.data_and_slot
15 |
--------------------------------------------------------------------------------
/src/driftpy/accounts/ws/user_stats.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from driftpy.accounts import DataAndSlot
4 | from driftpy.accounts.types import UserStatsAccountSubscriber
5 | from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber
6 | from driftpy.types import UserStatsAccount
7 |
8 |
class WebsocketUserStatsAccountSubscriber(
    WebsocketAccountSubscriber[UserStatsAccount], UserStatsAccountSubscriber
):
    """User-stats subscriber built on ``WebsocketAccountSubscriber``."""

    def get_user_stats_account_and_slot(
        self,
    ) -> Optional[DataAndSlot[UserStatsAccount]]:
        """Return the value currently held in ``self.data_and_slot``."""
        return self.data_and_slot
16 |
--------------------------------------------------------------------------------
/src/driftpy/address_lookup_table.py:
--------------------------------------------------------------------------------
1 | from solana.rpc.async_api import AsyncClient
2 | from solders.pubkey import Pubkey
3 | from solders.address_lookup_table_account import AddressLookupTableAccount
4 |
# Number of leading bytes skipped before the 32-byte addresses begin.
LOOKUP_TABLE_META_SIZE = 56


async def get_address_lookup_table(
    connection: AsyncClient, pubkey: Pubkey
) -> AddressLookupTableAccount:
    """Fetch and decode the address lookup table stored at ``pubkey``.

    Args:
        connection: RPC client used to read the account.
        pubkey: Address of the lookup table account.

    Returns:
        The decoded lookup table.

    Raises:
        ValueError: if the RPC response contains no account for ``pubkey``.
    """
    account_info = await connection.get_account_info(pubkey)
    account = account_info.value
    # A missing account yields a response without a value; report it clearly
    # instead of failing with an AttributeError on `None.data`.
    if account is None:
        raise ValueError(f"Address lookup table account {pubkey} not found")
    return decode_address_lookup_table(pubkey, account.data)
13 |
14 |
def decode_address_lookup_table(pubkey: Pubkey, data: bytes):
    """Build an ``AddressLookupTableAccount`` from raw account bytes.

    Everything after the first ``LOOKUP_TABLE_META_SIZE`` bytes is read as
    consecutive 32-byte public keys. Data no longer than the metadata prefix
    yields an empty address list.
    """
    address_offsets = range(LOOKUP_TABLE_META_SIZE, len(data), 32)
    addresses = [
        Pubkey.from_bytes(data[start : start + 32]) for start in address_offsets
    ]
    return AddressLookupTableAccount(pubkey, addresses)
25 |
--------------------------------------------------------------------------------
/src/driftpy/auction_subscriber/auction_subscriber.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Optional
3 |
4 | from events.events import Events, _EventSlot
5 | from solders.pubkey import Pubkey
6 |
7 | from driftpy.accounts.types import DataAndSlot, WebsocketProgramAccountOptions
8 | from driftpy.accounts.ws.program_account_subscriber import (
9 | WebSocketProgramAccountSubscriber,
10 | )
11 | from driftpy.auction_subscriber.types import AuctionSubscriberConfig
12 | from driftpy.decode.user import decode_user
13 | from driftpy.memcmp import get_user_filter, get_user_with_auction_filter
14 | from driftpy.types import UserAccount
15 |
16 |
class AuctionEvents(Events):
    """
    Subclass to explicitly declare the on_account_update event
    both in __events__ and as a typed attribute.
    """

    # Names of the events this emitter exposes.
    __events__ = ("on_account_update",)
    # Annotation only; no value is assigned at class level.
    on_account_update: _EventSlot
25 |
26 |
class AuctionSubscriber:
    """Emits ``on_account_update`` for user accounts matching the auction filters.

    Updates arrive through a ``WebSocketProgramAccountSubscriber`` that is
    created lazily on the first ``subscribe`` call.
    """

    def __init__(self, config: AuctionSubscriberConfig):
        self.drift_client = config.drift_client
        # Fall back to the client's connection commitment when none is given.
        commitment = config.commitment
        if commitment is None:
            commitment = self.drift_client.connection.commitment
        self.commitment = commitment
        self.resub_timeout_ms = config.resub_timeout_ms
        self.subscriber: Optional[WebSocketProgramAccountSubscriber] = None
        self.event_emitter = AuctionEvents()

    async def on_update(self, account_pubkey: str, data: DataAndSlot[UserAccount]):
        """Forward an update to listeners as ``(account, pubkey, slot)``."""
        pubkey = Pubkey.from_string(account_pubkey)
        self.event_emitter.on_account_update(
            data.data,
            pubkey,
            data.slot,  # type: ignore
        )

    def _build_subscriber(self) -> WebSocketProgramAccountSubscriber:
        """Create the program-account subscriber with the auction filters."""
        auction_filters = (get_user_filter(), get_user_with_auction_filter())
        return WebSocketProgramAccountSubscriber(
            "AuctionSubscriber",
            self.drift_client.program,
            WebsocketProgramAccountOptions(auction_filters, self.commitment, "base64"),
            self.on_update,
            decode_user,
            self.resub_timeout_ms,
        )

    async def subscribe(self):
        """Create the subscriber if needed and start it unless already subscribed."""
        if self.subscriber is None:
            self.subscriber = self._build_subscriber()

        if not self.subscriber.subscribed:
            await self.subscriber.subscribe()

    def unsubscribe(self):
        """Schedule the subscriber's ``unsubscribe`` as a task; no-op if never created."""
        if self.subscriber is not None:
            asyncio.create_task(self.subscriber.unsubscribe())
68 |
--------------------------------------------------------------------------------
/src/driftpy/auction_subscriber/grpc_auction_subscriber.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | from events.events import Events, _EventSlot
4 | from solders.pubkey import Pubkey
5 |
6 | from driftpy.accounts.grpc.account_subscriber import GrpcAccountSubscriber
7 | from driftpy.accounts.grpc.program_account_subscriber import (
8 | GrpcProgramAccountSubscriber,
9 | )
10 | from driftpy.accounts.types import DataAndSlot, GrpcProgramAccountOptions
11 | from driftpy.auction_subscriber.types import GrpcAuctionSubscriberConfig
12 | from driftpy.decode.user import decode_user
13 | from driftpy.memcmp import get_user_filter, get_user_with_auction_filter
14 | from driftpy.types import UserAccount
15 |
16 |
class GrpcAuctionEvents(Events):
    """Event emitter that declares the ``on_account_update`` event."""

    # Names of the events this emitter exposes.
    __events__ = ("on_account_update",)
    # Annotation only; no value is assigned at class level.
    on_account_update: _EventSlot
20 |
21 |
class GrpcAuctionSubscriber:
    """Emits ``on_account_update`` for user accounts matching the auction filters.

    Updates arrive through a ``GrpcProgramAccountSubscriber``. Every subscriber
    that is started is recorded in ``self.subscribers`` so that
    ``unsubscribe`` can stop it and a repeated ``subscribe`` is a no-op.
    """

    def __init__(self, config: GrpcAuctionSubscriberConfig):
        self.config = config
        self.drift_client = config.drift_client
        # Fall back to the client's connection commitment when none is given.
        self.commitment = (
            config.commitment
            if config.commitment is not None
            else self.drift_client.connection.commitment
        )
        self.subscribers: list[GrpcAccountSubscriber] = []
        # Most recently created subscriber; kept for callers that read it.
        self.subscriber = None
        self.event_emitter = GrpcAuctionEvents()

    async def on_update(self, account_pubkey: str, data: DataAndSlot[UserAccount]):
        """Forward an update to listeners as ``(account, pubkey, slot)``."""
        self.event_emitter.on_account_update(
            data.data, Pubkey.from_string(account_pubkey), data.slot
        )

    async def subscribe(self):
        """Start the gRPC program-account subscription once.

        Does nothing when a subscriber is already registered.
        """
        if self.subscribers:
            return

        filters = (get_user_filter(), get_user_with_auction_filter())
        options = GrpcProgramAccountOptions(filters, self.commitment)
        subscriber = GrpcProgramAccountSubscriber(
            grpc_config=self.config.grpc_config,
            subscription_name="AuctionSubscriber",
            program=self.drift_client.program,
            options=options,
            on_update=self.on_update,
            decode=decode_user,
        )
        # Register before awaiting: the guard above and `unsubscribe` both
        # read `self.subscribers`, so it must hold the subscriber being started.
        self.subscriber = subscriber
        self.subscribers.append(subscriber)
        await subscriber.subscribe()

    def unsubscribe(self):
        """Schedule ``unsubscribe`` for every registered subscriber and forget them.

        Must be called while an event loop is running, because the
        unsubscribe coroutines are scheduled with ``asyncio.create_task``.
        """
        if not self.subscribers:
            return

        for subscriber in self.subscribers:
            asyncio.create_task(subscriber.unsubscribe())
        self.subscribers.clear()
        self.subscriber = None
62 |
--------------------------------------------------------------------------------
/src/driftpy/auction_subscriber/types.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from solana.rpc.commitment import Commitment
4 |
5 | from driftpy.drift_client import DriftClient
6 | from driftpy.types import GrpcConfig
7 |
8 |
class AuctionSubscriberConfig:
    """Settings consumed by ``AuctionSubscriber``."""

    def __init__(
        self,
        drift_client: DriftClient,
        commitment: Optional[Commitment] = None,
        resub_timeout_ms: Optional[int] = None,
    ):
        """
        Args:
            drift_client: Client the subscriber works with.
            commitment: Optional commitment level; stored as given.
            resub_timeout_ms: Optional resubscribe timeout in milliseconds;
                stored as given.
        """
        self.drift_client = drift_client
        self.commitment = commitment
        self.resub_timeout_ms = resub_timeout_ms
19 |
20 |
class GrpcAuctionSubscriberConfig(AuctionSubscriberConfig):
    """``AuctionSubscriberConfig`` extended with gRPC connection settings."""

    def __init__(
        self,
        drift_client: DriftClient,
        grpc_config: GrpcConfig,
        commitment: Optional[Commitment] = None,
    ):
        """
        Args:
            drift_client: Client the subscriber works with.
            grpc_config: gRPC settings; stored as given.
            commitment: Optional commitment level.
        """
        # resub_timeout_ms is not passed, so it keeps the base default (None).
        super().__init__(drift_client, commitment)
        self.grpc_config = grpc_config
30 |
--------------------------------------------------------------------------------
/src/driftpy/constants/__init__.py:
--------------------------------------------------------------------------------
1 | from driftpy.constants.numeric_constants import *
2 |
--------------------------------------------------------------------------------
/src/driftpy/decode/pull_oracle.py:
--------------------------------------------------------------------------------
1 | from driftpy.decode.user import (
2 | read_uint8,
3 | read_int32_le,
4 | read_bigint64le,
5 | )
6 | from driftpy.types import PriceUpdateV2, PriceFeedMessage
7 |
8 |
def decode_pull_oracle(buffer: bytes) -> PriceUpdateV2:
    """Decode the price fields of a pull-oracle account buffer.

    Only the price, confidence, exponent, EMA price, EMA confidence and
    posted slot are read; all other bytes are skipped.

    Args:
        buffer: Raw account bytes.

    Returns:
        A ``PriceUpdateV2`` wrapping a ``PriceFeedMessage`` and the posted slot.
    """
    # Skip the 8-byte prefix and the first 32-byte field (not decoded here).
    cursor = 8 + 32

    # The verification level takes 1 byte when its low bit is set (Full)
    # and 2 bytes otherwise (Partial).
    is_full_verification = read_uint8(buffer, cursor) & 0x1
    cursor += 1 if is_full_verification else 2

    cursor += 32  # skip feed_id

    # From here on the layout is fixed, so fields are read at offsets
    # relative to the start of the price.
    price = read_bigint64le(buffer, cursor, True)
    conf = read_bigint64le(buffer, cursor + 8, False)
    exponent = read_int32_le(buffer, cursor + 16, True)

    # publish_time and prev_publish_time (8 bytes each) sit between the
    # exponent and the EMA values and are skipped.
    ema_start = cursor + 16 + 4 + 8 + 8
    ema_price = read_bigint64le(buffer, ema_start, True)
    ema_conf = read_bigint64le(buffer, ema_start + 8, False)
    posted_slot = read_bigint64le(buffer, ema_start + 16, False)

    return PriceUpdateV2(
        PriceFeedMessage(price, conf, exponent, ema_price, ema_conf),
        posted_slot,
    )
45 |
--------------------------------------------------------------------------------
/src/driftpy/decode/user_stat.py:
--------------------------------------------------------------------------------
1 | from solders.pubkey import Pubkey
2 |
3 | from driftpy.decode.user import (
4 | read_bigint64le,
5 | read_int32_le,
6 | read_uint8,
7 | read_uint16_le,
8 | )
9 | from driftpy.types import UserFees, UserStatsAccount
10 |
11 |
def decode_user_stat(buffer: bytes) -> UserStatsAccount:
    """Decode a raw user-stats account buffer into a ``UserStatsAccount``.

    Fields are read sequentially at fixed widths, so the order of the reads
    below is the order of the fields in the buffer.

    Args:
        buffer: Raw account bytes; decoding starts after the first 8 bytes.

    Returns:
        The decoded ``UserStatsAccount``.
    """
    # Skip the 8-byte prefix.
    offset = 8
    authority = Pubkey(buffer[offset : offset + 32])
    offset += 32
    referrer = Pubkey(buffer[offset : offset + 32])
    offset += 32

    # Six consecutive 64-bit fee counters, collected into UserFees below.
    total_fee_paid = read_bigint64le(buffer, offset, False)
    offset += 8
    total_fee_rebate = read_bigint64le(buffer, offset, False)
    offset += 8
    total_token_discount = read_bigint64le(buffer, offset, False)
    offset += 8
    total_referee_discount = read_bigint64le(buffer, offset, False)
    offset += 8
    total_referrer_reward = read_bigint64le(buffer, offset, False)
    offset += 8
    current_epoch_referrer_reward = read_bigint64le(buffer, offset, False)
    offset += 8

    user_fees = UserFees(
        total_fee_paid,
        total_fee_rebate,
        total_token_discount,
        total_referee_discount,
        total_referrer_reward,
        current_epoch_referrer_reward,
    )

    next_epoch_ts = read_bigint64le(buffer, offset, True)
    offset += 8

    # 30-day volumes, each 8 bytes.
    maker_volume_30d = read_bigint64le(buffer, offset, False)
    offset += 8

    taker_volume_30d = read_bigint64le(buffer, offset, False)
    offset += 8

    filler_volume_30d = read_bigint64le(buffer, offset, False)
    offset += 8

    # Timestamps of the last update of each 30-day volume, each 8 bytes.
    last_maker_volume_30d_ts = read_bigint64le(buffer, offset, True)
    offset += 8

    last_taker_volume_30d_ts = read_bigint64le(buffer, offset, True)
    offset += 8

    last_filler_volume_30d_ts = read_bigint64le(buffer, offset, True)
    offset += 8

    if_staked_quote_asset_amount = read_bigint64le(buffer, offset, False)
    offset += 8

    number_of_sub_accounts = read_uint16_le(buffer, offset)
    offset += 2

    number_of_sub_accounts_created = read_uint16_le(buffer, offset)
    offset += 2

    # Only the lowest bit of the status byte is used to derive is_referrer.
    referrer_status = read_uint8(buffer, offset)
    is_referrer = (referrer_status & 0x1) == 1
    offset += 1

    disable_update_perp_bid_ask_twap = read_uint8(buffer, offset) == 1
    offset += 1

    # One byte is skipped without being decoded.
    # NOTE(review): presumably a field not exposed on UserStatsAccount —
    # confirm against the on-chain account layout.
    offset += 1

    fuel_overflow_status = read_uint8(buffer, offset)
    offset += 1

    # Six 4-byte fuel counters.
    fuel_insurance = read_int32_le(buffer, offset, False)
    offset += 4

    fuel_deposits = read_int32_le(buffer, offset, False)
    offset += 4

    fuel_borrows = read_int32_le(buffer, offset, False)
    offset += 4

    fuel_positions = read_int32_le(buffer, offset, False)
    offset += 4

    fuel_taker = read_int32_le(buffer, offset, False)
    offset += 4

    fuel_maker = read_int32_le(buffer, offset, False)
    offset += 4

    if_staked_gov_token_amount = read_bigint64le(buffer, offset, False)
    offset += 8

    last_fuel_if_bonus_update_ts = read_int32_le(buffer, offset, False)
    offset += 4

    # Padding is not read from the buffer; a zero-filled list is supplied.
    padding = [0] * 12

    return UserStatsAccount(
        authority,
        referrer,
        user_fees,
        next_epoch_ts,
        maker_volume_30d,
        taker_volume_30d,
        filler_volume_30d,
        last_maker_volume_30d_ts,
        last_taker_volume_30d_ts,
        last_filler_volume_30d_ts,
        if_staked_quote_asset_amount,
        number_of_sub_accounts,
        number_of_sub_accounts_created,
        is_referrer,
        disable_update_perp_bid_ask_twap,
        fuel_overflow_status,
        fuel_insurance,
        fuel_deposits,
        fuel_borrows,
        fuel_positions,
        fuel_taker,
        fuel_maker,
        if_staked_gov_token_amount,
        last_fuel_if_bonus_update_ts,
        padding,
    )
136 |
--------------------------------------------------------------------------------
/src/driftpy/decode/utils.py:
--------------------------------------------------------------------------------
def decode_name(bytes_list: list[int]):
    """Decode byte values as UTF-8 text and strip surrounding whitespace."""
    return bytes(bytes_list).decode("utf-8").strip()
4 |
--------------------------------------------------------------------------------
/src/driftpy/dlob/client_types.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from dataclasses import dataclass
3 |
4 | from driftpy.dlob.dlob import DLOB
5 | from driftpy.drift_client import DriftClient
6 |
7 |
class DLOBSource(ABC):
    """Abstract provider of ``DLOB`` instances."""

    @abstractmethod
    async def get_DLOB(self, slot: int) -> DLOB:
        """Return a ``DLOB`` for the given ``slot``."""
        pass
12 |
13 |
class SlotSource(ABC):
    """Abstract provider of a slot number."""

    @abstractmethod
    def get_slot(self) -> int:
        """Return a slot number."""
        pass
18 |
19 |
@dataclass
class DLOBClientConfig:
    """Configuration bundle for a DLOB client."""

    # Client used by the DLOB client.
    drift_client: DriftClient
    # Where DLOB instances are obtained from.
    dlob_source: DLOBSource
    # Where slot numbers are obtained from.
    slot_source: SlotSource
    # Update frequency.
    # NOTE(review): unit is not visible here — confirm against the consumer.
    update_frequency: int
26 |
--------------------------------------------------------------------------------
/src/driftpy/dlob/dlob_helpers.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Union
2 | from driftpy.types import (
3 | MarketType,
4 | PerpMarketAccount,
5 | SpotMarketAccount,
6 | StateAccount,
7 | is_variant,
8 | )
9 |
10 |
def get_maker_rebate(
    market_type: MarketType,
    state_account: StateAccount,
    market_account: Union[PerpMarketAccount, SpotMarketAccount],
):
    """Return the maker rebate as a ``(numerator, denominator)`` pair.

    The base values come from the first fee tier of the perp fee structure
    for perp markets and of the spot fee structure otherwise. When the
    market has a non-zero ``fee_adjustment``, the numerator is increased by
    ``(numerator * fee_adjustment) // 100``.
    """
    if is_variant(market_type, "Perp"):
        fee_structure = state_account.perp_fee_structure
    else:
        fee_structure = state_account.spot_fee_structure

    first_tier = fee_structure.fee_tiers[0]
    numerator = first_tier.maker_rebate_numerator
    denominator = first_tier.maker_rebate_denominator

    # A missing adjustment is treated the same as zero: no change.
    adjustment = market_account.fee_adjustment
    if adjustment is not None and adjustment != 0:
        numerator += (numerator * adjustment) // 100

    return numerator, denominator
40 |
41 |
def get_node_lists(order_lists):
    """Yield every node list of every perp market, then of every spot market.

    For each market the lists are yielded in this order: resting limit
    bid/ask, taking limit bid/ask, market bid/ask, floating limit bid/ask,
    trigger above/below.
    """
    from driftpy.dlob.dlob_node import MarketNodeLists

    order_lists: Dict[str, Dict[int, MarketNodeLists]]

    # (attribute on the market's node lists, key within that attribute),
    # in the order the lists are yielded.
    yield_order = (
        ("resting_limit", "bid"),
        ("resting_limit", "ask"),
        ("taking_limit", "bid"),
        ("taking_limit", "ask"),
        ("market", "bid"),
        ("market", "ask"),
        ("floating_limit", "bid"),
        ("floating_limit", "ask"),
        ("trigger", "above"),
        ("trigger", "below"),
    )

    for market_kind in ("perp", "spot"):
        for node_lists in order_lists.get(market_kind, {}).values():
            for attribute, key in yield_order:
                yield getattr(node_lists, attribute)[key]
70 |
--------------------------------------------------------------------------------
/src/driftpy/drift_user_stats.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 |
4 | from solders.pubkey import Pubkey
5 | from solana.rpc.commitment import Commitment, Confirmed
6 |
7 | from driftpy.accounts.types import DataAndSlot
8 | from driftpy.types import ReferrerInfo, UserStatsAccount
9 | from driftpy.accounts.ws.user_stats import WebsocketUserStatsAccountSubscriber
10 | from driftpy.addresses import (
11 | get_user_account_public_key,
12 | get_user_stats_account_public_key,
13 | )
14 |
15 |
@dataclass
class UserStatsSubscriptionConfig:
    """Subscription options consumed when a ``DriftUserStats`` is constructed."""

    # Commitment level handed to the account subscriber.
    commitment: Commitment = Confirmed
    # Optional resubscribe timeout in milliseconds (not read in this module).
    resub_timeout_ms: Optional[int] = None
    # Optional pre-fetched account data forwarded to the subscriber as its
    # ``initial_data``.
    initial_data: Optional[DataAndSlot[UserStatsAccount]] = None
21 |
22 |
class DriftUserStats:
    """Handle on one user-stats account, backed by a websocket subscriber.

    The instance owns a ``WebsocketUserStatsAccountSubscriber`` for
    ``user_stats_account_pubkey`` and exposes the latest account data it holds.
    """

    def __init__(
        self,
        drift_client,
        user_stats_account_pubkey: Pubkey,
        config: UserStatsSubscriptionConfig,
    ):
        """Build the account subscriber; no subscription is started here.

        Args:
            drift_client: Client providing ``program`` and ``program_id``.
            user_stats_account_pubkey: Address of the user-stats account.
            config: Supplies the commitment level and optional initial data.
        """
        self.drift_client = drift_client
        self.user_stats_account_pubkey = user_stats_account_pubkey
        self.account_subscriber = WebsocketUserStatsAccountSubscriber(
            user_stats_account_pubkey,
            drift_client.program,
            config.commitment,
            initial_data=config.initial_data,
        )
        self.subscribed = False

    async def subscribe(self) -> bool:
        """Start the account subscription if it is not already running.

        Returns:
            ``True`` once subscribed, including when the subscription was
            already active (the early exit used to return ``None``).
        """
        if self.subscribed:
            return True

        await self.account_subscriber.subscribe()
        self.subscribed = True
        return True

    async def fetch_accounts(self):
        """Ask the subscriber to fetch the account once."""
        await self.account_subscriber.fetch()

    def unsubscribe(self):
        """Stop the account subscription and allow a later ``subscribe()``.

        The subscriber's return value is passed through unchanged so that, if
        its ``unsubscribe`` happens to be awaitable, the caller can await it.
        """
        result = self.account_subscriber.unsubscribe()
        # Without this reset a later subscribe() would return early and never
        # re-open the subscription.
        self.subscribed = False
        return result

    def get_account_and_slot(self) -> DataAndSlot[UserStatsAccount]:
        """Return the subscriber's current account data together with its slot."""
        return self.account_subscriber.get_user_stats_account_and_slot()

    def get_account(self) -> UserStatsAccount:
        """Return only the account data held by the subscriber."""
        return self.account_subscriber.get_user_stats_account_and_slot().data

    def get_referrer_info(self) -> Optional[ReferrerInfo]:
        """Derive the referrer's user and user-stats addresses.

        Returns:
            ``None`` when the account's ``referrer`` is the default public key,
            otherwise a ``ReferrerInfo`` built from the referrer's user account
            (sub-account id 0) and user-stats account addresses.
        """
        # Read the account once so all derivations use the same snapshot.
        referrer = self.get_account().referrer
        if referrer == Pubkey.default():
            return None

        program_id = self.drift_client.program_id
        return ReferrerInfo(
            get_user_account_public_key(program_id, referrer, 0),
            get_user_stats_account_public_key(program_id, referrer),
        )
73 |
--------------------------------------------------------------------------------
/src/driftpy/events/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/src/driftpy/events/__init__.py
--------------------------------------------------------------------------------
/src/driftpy/events/event_list.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from dataclasses import dataclass
3 |
4 | from driftpy.events.types import WrappedEvent, EventSubscriptionOrderDirection, SortFn
5 |
6 |
@dataclass
class Node:
    """One cell of the doubly linked list used by ``EventList``."""

    event: WrappedEvent
    # Forward references instead of the builtin function ``any``, which is not
    # a type.
    next: Optional["Node"] = None
    prev: Optional["Node"] = None


class EventList:
    """Size-bounded doubly linked list of events kept in ``sort_fn`` order.

    ``insert`` walks from the head and places the new event in front of the
    first stored event for which ``sort_fn(stored, new)`` returns the halt
    value (``-1`` when ``order_direction`` is ``"asc"``, ``1`` otherwise); if
    no stored event matches, the new event becomes the tail. Whenever the list
    grows beyond ``max_size`` the tail node is dropped.
    """

    def __init__(
        self,
        max_size: int,
        sort_fn: SortFn,
        order_direction: EventSubscriptionOrderDirection,
    ):
        self.size = 0
        self.max_size = max_size
        self.sort_fn = sort_fn
        self.order_direction = order_direction
        self.head = None
        self.tail = None

    def insert(self, event: WrappedEvent) -> None:
        """Insert ``event`` at its ordered position, evicting the tail if full."""
        self.size += 1
        new_node = Node(event)
        if self.head is None:
            self.head = self.tail = new_node
            return

        halt_condition = -1 if self.order_direction == "asc" else 1

        if self.sort_fn(self.head.event, new_node.event) == halt_condition:
            # The new event goes in front of the current head.
            new_node.next = self.head
            self.head.prev = new_node
            self.head = new_node
        else:
            # Advance until the following node satisfies the halt condition
            # (or the end of the list is reached).
            current_node = self.head
            while (
                current_node.next is not None
                and self.sort_fn(current_node.next.event, new_node.event)
                != halt_condition
            ):
                current_node = current_node.next

            successor = current_node.next
            new_node.next = successor
            new_node.prev = current_node
            current_node.next = new_node
            if successor is None:
                self.tail = new_node
            else:
                successor.prev = new_node

        if self.size > self.max_size:
            self.detach()

    def detach(self) -> None:
        """Remove the tail node; does nothing when the list is empty."""
        node = self.tail
        if node is None:
            return

        if node.prev is not None:
            node.prev.next = node.next
        else:
            self.head = node.next

        if node.next is not None:
            node.next.prev = node.prev
        else:
            self.tail = node.prev

        self.size -= 1

    def to_array(self) -> list[WrappedEvent]:
        """Return the stored events, head first, as a new list."""
        return list(self)

    def __iter__(self):
        """Iterate over the stored events from head to tail."""
        node = self.head
        while node is not None:
            yield node.event
            node = node.next
84 |
--------------------------------------------------------------------------------
/src/driftpy/events/event_subscriber.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from anchorpy import EventParser, Program
4 | from events import Events as EventEmitter
5 | from solana.rpc.async_api import AsyncClient
6 | from solders.signature import Signature
7 |
8 | from driftpy.events.event_list import EventList
9 | from driftpy.events.parse import parse_logs
10 | from driftpy.events.sort import get_sort_fn
11 | from driftpy.events.tx_event_cache import TxEventCache
12 | from driftpy.events.types import EventSubscriptionOptions, EventType, WrappedEvent
13 |
14 |
class EventSubscriber:
    """Collects program events parsed from transaction logs.

    Logs delivered by the configured log provider are parsed into
    ``WrappedEvent`` objects, stored per event type in an ``EventList``,
    announced through ``event_emitter.new_event`` and remembered per
    transaction signature in a ``TxEventCache``.
    """

    def __init__(
        self,
        connection: AsyncClient,
        program: Program,
        options: Optional[EventSubscriptionOptions] = None,
    ):
        """Set up event storage, the parser and the log provider.

        Args:
            connection: RPC client passed to ``options.get_log_provider``.
            program: Program whose id and coder are used to parse events.
            options: Subscription options. When omitted, a fresh
                ``EventSubscriptionOptions.default()`` is created per instance
                (a default built in the signature would be shared by all
                instances).
        """
        self.connection = connection
        self.program = program
        self.options = (
            options if options is not None else EventSubscriptionOptions.default()
        )
        self.subscribed = False
        self.event_list_map: dict[EventType, EventList] = {}
        for event_type in self.options.event_types:
            self.event_list_map[event_type] = EventList(
                self.options.max_events_per_type,
                get_sort_fn(self.options.order_by, self.options.order_dir),
                self.options.order_dir,
            )
        self.event_parser = EventParser(self.program.program_id, self.program.coder)
        self.log_provider = self.options.get_log_provider(connection)
        self.tx_event_cache = TxEventCache(self.options.max_tx)
        self.event_emitter = EventEmitter(("new_event",))

    def subscribe(self):
        """Register ``handle_tx_logs`` with the log provider."""
        self.log_provider.subscribe(self.handle_tx_logs)
        self.subscribed = True

    def unsubscribe(self):
        """Stop receiving logs from the log provider."""
        self.log_provider.unsubscribe()
        self.subscribed = False

    def handle_tx_logs(
        self,
        tx_sig: Signature,
        slot: int,
        logs: list[str],
    ):
        """Process the logs of one transaction unless it was already handled."""
        cache_key = str(tx_sig)
        if self.tx_event_cache.has(cache_key):
            return

        wrapped_events = self.parse_events_from_logs(tx_sig, slot, logs)

        # Store every event before emitting any of them, so listeners see the
        # lists fully updated for this transaction.
        for wrapped_event in wrapped_events:
            self.event_list_map[wrapped_event.event_type].insert(wrapped_event)

        for wrapped_event in wrapped_events:
            self.event_emitter.new_event(wrapped_event)

        self.tx_event_cache.add(cache_key, wrapped_events)

    def parse_events_from_logs(self, tx_sig: Signature, slot: int, logs: list[str]):
        """Parse ``logs`` and wrap the events whose type is subscribed to.

        ``tx_sig_index`` is the event's position among all parsed events of the
        transaction, including those that are filtered out.
        """
        wrapped_events = []

        events = parse_logs(self.program, logs)

        for index, event in enumerate(events):
            if event.name in self.event_list_map:
                wrapped_events.append(
                    WrappedEvent(
                        event_type=event.name,
                        tx_sig=tx_sig,
                        slot=slot,
                        tx_sig_index=index,
                        data=event.data,
                    )
                )

        return wrapped_events

    def get_event_list(self, event_type: EventType) -> Optional[EventList]:
        """Return the ``EventList`` for ``event_type``, or ``None`` if untracked."""
        return self.event_list_map.get(event_type)

    def get_events_array(self, event_type: EventType) -> Optional[list[WrappedEvent]]:
        """Return the stored events for ``event_type`` as a list, or ``None``."""
        event_list = self.event_list_map.get(event_type)
        return None if event_list is None else event_list.to_array()

    def get_events_by_tx(self, tx_sig: str) -> Optional[list[WrappedEvent]]:
        """Return the cached events of transaction ``tx_sig``, or ``None``."""
        return self.tx_event_cache.get(tx_sig)
91 |
--------------------------------------------------------------------------------
/src/driftpy/events/fetch_logs.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | import jsonrpcclient
4 | from solana.rpc.async_api import AsyncClient
5 | from solana.rpc.commitment import Commitment
6 | from solders.pubkey import Pubkey
7 | from solders.rpc.responses import (
8 | RpcConfirmedTransactionStatusWithSignature,
9 | )
10 | from solders.signature import Signature
11 |
12 |
async def fetch_logs(
    connection: AsyncClient,
    address: Pubkey,
    commitment: Commitment,
    before_tx: Signature = None,
    until_tx: Signature = None,
    limit: int = None,
    batch_size: int = None,
) -> list[tuple[str, int, list[str]]]:
    """Fetch the logs of successful transactions that mention ``address``.

    Args:
        connection: RPC client used for all requests.
        address: Account whose signatures are listed.
        commitment: Commitment level for the signature and transaction lookups.
        before_tx: Forwarded to ``get_signatures_for_address`` as ``before``.
        until_tx: Forwarded to ``get_signatures_for_address`` as ``until``.
        limit: Forwarded to ``get_signatures_for_address`` as ``limit``.
        batch_size: Number of transactions requested per batch; 25 when
            ``None``.

    Returns:
        ``(signature, slot, logs)`` tuples, with signatures as strings taken
        from the JSON transaction payload. Signatures are requested in
        ascending slot order and transactions that errored are left out.

    Raises:
        Exception: If the signatures response carries no value.
    """
    response = await connection.get_signatures_for_address(
        address, before_tx, until_tx, limit, commitment
    )

    if response.value is None:
        raise Exception("Error with get_signature_for_address")

    # Oldest first, and only transactions that did not fail.
    successful_signatures = [
        signature
        for signature in sorted(response.value, key=lambda x: x.slot)
        if not signature.err
    ]

    if len(successful_signatures) == 0:
        return []

    batch_size = batch_size if batch_size is not None else 25

    # One batched request per chunk, all issued concurrently.
    chunked_transactions_logs = await asyncio.gather(
        *[
            fetch_transactions(connection, signatures, commitment)
            for signatures in chunk(successful_signatures, batch_size)
        ]
    )

    return [
        transaction_logs
        for transactions_logs in chunked_transactions_logs
        for transaction_logs in transactions_logs
    ]
55 |
56 |
async def fetch_transactions(
    connection: AsyncClient,
    signatures: list[RpcConfirmedTransactionStatusWithSignature],
    commitment: Commitment,
) -> list[tuple[str, int, list[str]]]:
    """Fetch several transactions in one batched ``getTransaction`` call.

    Args:
        connection: RPC client whose underlying HTTP session is used directly.
        signatures: Signature entries to look up.
        commitment: Commitment level put into every request.

    Returns:
        ``(signature, slot, logs)`` for every transaction the RPC returned a
        result for; the signature is the string found in the JSON payload.
        An empty list is returned when the request takes longer than 10
        seconds.

    Raises:
        ValueError: If the batch response, or any entry in it, is an error.
    """
    rpc_requests = [
        jsonrpcclient.request(
            "getTransaction",
            (
                str(signature.signature),
                {"commitment": commitment, "maxSupportedTransactionVersion": 0},
            ),
        )
        for signature in signatures
    ]

    try:
        post = connection._provider.session.post(
            connection._provider.endpoint_uri,
            json=rpc_requests,
            headers={"content-encoding": "gzip"},
        )
        resp = await asyncio.wait_for(post, timeout=10)
    except asyncio.TimeoutError:
        # Best effort: a timed-out batch contributes no logs.
        print("request to rpc timed out")
        return []

    parsed_resp = jsonrpcclient.parse(resp.json())

    if isinstance(parsed_resp, jsonrpcclient.Error):
        raise ValueError(f"Error fetching transactions: {parsed_resp.message}")

    response = []
    for rpc_result in parsed_resp:
        if not isinstance(rpc_result, jsonrpcclient.Ok):
            raise ValueError(f"Error fetching transactions - not ok: {rpc_result}")

        # Entries with an empty result are skipped.
        if rpc_result.result:
            tx_sig = rpc_result.result["transaction"]["signatures"][0]
            slot = rpc_result.result["slot"]
            logs = rpc_result.result["meta"]["logMessages"]
            response.append((tx_sig, slot, logs))

    return response
101 |
102 |
def chunk(array, size):
    """Split ``array`` into consecutive slices holding at most ``size`` items."""
    pieces = []
    for start in range(0, len(array), size):
        pieces.append(array[start : start + size])
    return pieces
105 |
--------------------------------------------------------------------------------
/src/driftpy/events/parse.py:
--------------------------------------------------------------------------------
1 | import binascii
2 | import re
3 | import base64
4 |
5 | from typing import Tuple, Optional
6 | from anchorpy import Program, Event
7 |
# Program id string that identifies Drift in invoke log lines.
DRIFT_PROGRAM_ID: str = "dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH"
# Start of the log line that marks an invocation of that program.
DRIFT_PROGRAM_START: str = f"Program {DRIFT_PROGRAM_ID} invoke"

# Prefixes of the two log-line kinds whose payload is decoded as an event.
PROGRAM_LOG: str = "Program log: "
PROGRAM_DATA: str = "Program data: "

# Offsets at which the payload starts after each prefix.
PROGRAM_LOG_START_INDEX: int = len(PROGRAM_LOG)
PROGRAM_DATA_START_INDEX: int = len(PROGRAM_DATA)
14 |
15 |
class ExecutionContext:
    """Stack of program names, with the most recently pushed one on top."""

    def __init__(self):
        self.stack: list[str] = []

    def program(self):
        """Return the top of the stack; raise ``ValueError`` when it is empty."""
        if not self.stack:
            raise ValueError("Expected the stack to have elements")
        return self.stack[-1]

    def push(self, program: str):
        """Put ``program`` on top of the stack."""
        self.stack.append(program)

    def pop(self):
        """Remove and return the top entry; raise ``ValueError`` when empty."""
        if not self.stack:
            raise ValueError("Expected the stack to have elements")
        return self.stack.pop()
32 |
33 |
def parse_logs(program: Program, logs: list[str]) -> list[Event]:
    """Return the events decoded from one transaction's log lines.

    Lines are handled in order while a stack of executing programs is
    maintained; processing stops at the first line starting with
    ``"Log truncated"``.
    """
    context = ExecutionContext()
    parsed: list[Event] = []

    for line in logs:
        if line.startswith("Log truncated"):
            break

        event, entered_program, exited = handle_log(context, line, program)
        if event:
            parsed.append(event)
        if entered_program:
            context.push(entered_program)
        if exited:
            context.pop()

    return parsed
50 |
51 |
def handle_log(
    execution: ExecutionContext, log: str, program: Program
) -> Tuple[Optional[Event], Optional[str], bool]:
    """Route one log line to the program-log or system-log handler.

    The line is treated as a program log only while the program on top of the
    execution stack is ``DRIFT_PROGRAM_ID``. Returns
    ``(event, new_program, did_pop)``.
    """
    if execution.stack and execution.program() == DRIFT_PROGRAM_ID:
        return handle_program_log(log, program)

    new_program, did_pop = handle_system_log(log)
    return (None, new_program, did_pop)
59 |
60 |
def handle_program_log(
    log: str, program: Program
) -> Tuple[Optional[Event], Optional[str], bool]:
    """Decode an event from a program log line, or fall back to system handling.

    Lines starting with ``PROGRAM_LOG`` or ``PROGRAM_DATA`` have their payload
    base64-decoded and passed to ``program.coder.events.parse``. Any other line
    is handed to ``handle_system_log``.

    Returns:
        ``(event, new_program, did_pop)``. For log/data lines the last two
        items are always ``None`` and ``False``; ``event`` is ``None`` when the
        payload is not decodable base64 or is shorter than 8 bytes.
    """
    if log.startswith(PROGRAM_LOG):
        payload = log[PROGRAM_LOG_START_INDEX:]
    elif log.startswith(PROGRAM_DATA):
        payload = log[PROGRAM_DATA_START_INDEX:]
    else:
        new_program, did_pop = handle_system_log(log)
        return (None, new_program, did_pop)

    try:
        decoded = base64.b64decode(payload)
    except ValueError:
        # binascii.Error (bad padding) is a ValueError subclass; a non-ASCII
        # message raises a plain ValueError, which previously escaped and
        # aborted parsing of the whole transaction.
        return (None, None, False)

    # Too short to be parsed; presumably 8 bytes is the event discriminator
    # length -- TODO confirm against the coder.
    if len(decoded) < 8:
        return (None, None, False)

    return (program.coder.events.parse(decoded), None, False)
80 |
81 |
def handle_system_log(log: str) -> Tuple[Optional[str], bool]:
    """Classify a runtime log line as program exit, program entry or neither.

    Only the text before the first ``":"`` is inspected.

    Returns:
        ``(None, True)`` for a ``Program ... success`` line,
        ``(DRIFT_PROGRAM_ID, False)`` when the line starts with
        ``DRIFT_PROGRAM_START``, ``("cpi", False)`` for any other line
        containing ``"invoke"``, and ``(None, False)`` otherwise.
    """
    prefix = log.partition(":")[0]

    if re.search(r"Program (.*) success", prefix) is not None:
        return (None, True)
    if prefix.startswith(DRIFT_PROGRAM_START):
        return (DRIFT_PROGRAM_ID, False)
    if "invoke" in prefix:
        return ("cpi", False)
    return (None, False)
93 |
--------------------------------------------------------------------------------
/src/driftpy/events/polling_log_provider.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from solana.rpc.async_api import AsyncClient
3 | from solana.rpc.commitment import Commitment
4 |
5 | from solders.pubkey import Pubkey
6 |
7 | from driftpy.events.fetch_logs import fetch_logs
8 | from driftpy.events.types import LogProvider, LogProviderCallback
9 |
10 |
class PollingLogProvider(LogProvider):
    """Log provider that polls the RPC for new transaction logs of an address."""

    def __init__(
        self,
        connection: AsyncClient,
        address: Pubkey,
        commitment: Commitment,
        frequency: float,
        batch_size: int = 25,
    ):
        """Store the polling parameters; polling starts in ``subscribe``.

        Args:
            connection: RPC client used by ``fetch_logs``.
            address: Account whose transaction logs are polled.
            commitment: Commitment level for the lookups.
            frequency: Seconds slept between two polls.
            batch_size: Transactions requested per batch by ``fetch_logs``.
        """
        self.connection = connection
        self.address = address
        self.commitment = commitment
        self.frequency = frequency
        self.batch_size = batch_size
        self.task = None
        # Newest signature already delivered; used as the ``until`` bound of
        # the next poll.
        self.most_recent_tx = None

    def subscribe(self, callback: LogProviderCallback):
        """Start the background polling task unless one is already running."""
        if self.is_subscribed():
            return
        self.task = asyncio.create_task(self._poll(callback))

    async def _poll(self, callback: LogProviderCallback):
        """Poll forever, handing every fetched transaction to ``callback``."""
        # Local import: only needed to turn the JSON signature string back into
        # the object the RPC client takes as ``until``.
        from solders.signature import Signature

        first_fetch = True
        while True:
            try:
                txs_logs = await fetch_logs(
                    self.connection,
                    self.address,
                    self.commitment,
                    None,
                    self.most_recent_tx,
                    # The very first poll only asks for the latest signature.
                    1 if first_fetch else None,
                    self.batch_size,
                )

                for signature, slot, logs in txs_logs:
                    callback(signature, slot, logs)

                if txs_logs:
                    # Advance the cursor to the newest delivered transaction so
                    # the next poll only asks for what came after it. This was
                    # never assigned before, so every poll re-fetched the whole
                    # signature window.
                    # NOTE(review): transactions from a batch that timed out
                    # inside fetch_logs are not retried once the cursor has
                    # moved past them.
                    newest = max(txs_logs, key=lambda entry: entry[1])[0]
                    self.most_recent_tx = (
                        Signature.from_string(newest)
                        if isinstance(newest, str)
                        else newest
                    )

                first_fetch = False
            except Exception as e:
                print("Error fetching logs", e)

            await asyncio.sleep(self.frequency)

    def is_subscribed(self) -> bool:
        """Return whether a polling task currently exists."""
        return self.task is not None

    def unsubscribe(self):
        """Cancel the polling task, if any."""
        if self.is_subscribed():
            self.task.cancel()
            self.task = None
63 |
--------------------------------------------------------------------------------
/src/driftpy/events/sort.py:
--------------------------------------------------------------------------------
1 | from driftpy.events.types import (
2 | WrappedEvent,
3 | EventSubscriptionOrderBy,
4 | EventSubscriptionOrderDirection,
5 | SortFn,
6 | )
7 |
8 |
def client_sort_asc_fn(*_events) -> int:
    """Sort function for client order, ascending: always returns ``-1``.

    Sort functions are invoked with ``(current_event, new_event)``; both
    arguments are accepted and ignored. A zero-parameter signature raised
    ``TypeError`` as soon as it was called that way.
    """
    return -1
11 |
12 |
def client_sort_desc_fn(*_events) -> int:
    """Sort function for client order, descending: always returns ``1``.

    Sort functions are invoked with ``(current_event, new_event)``; both
    arguments are accepted and ignored. A zero-parameter signature raised
    ``TypeError`` as soon as it was called that way.
    """
    return 1
15 |
16 |
def blockchain_sort_fn(current_event: WrappedEvent, new_event: WrappedEvent) -> int:
    """Compare two events by slot, then by index within the transaction.

    Returns ``-1`` when ``current_event`` comes strictly before ``new_event``
    and ``1`` otherwise (including when both slot and index are equal).
    """
    if current_event.slot != new_event.slot:
        return -1 if current_event.slot < new_event.slot else 1

    if current_event.tx_sig_index < new_event.tx_sig_index:
        return -1
    return 1
22 |
23 |
def get_sort_fn(
    order_by: EventSubscriptionOrderBy, order_dir: EventSubscriptionOrderDirection
) -> SortFn:
    """Pick the sort function for the requested ordering.

    ``order_by == "client"`` selects the constant ascending or descending
    function according to ``order_dir``; every other value selects
    ``blockchain_sort_fn``.
    """
    if order_by != "client":
        return blockchain_sort_fn
    if order_dir == "asc":
        return client_sort_asc_fn
    return client_sort_desc_fn
31 |
--------------------------------------------------------------------------------
/src/driftpy/events/tx_event_cache.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Dict, List
2 | from dataclasses import dataclass
3 |
4 | from driftpy.events.types import WrappedEvent
5 |
6 |
@dataclass
class Node:
    """One entry of the cache's doubly linked list."""

    key: str
    value: List[WrappedEvent]
    # Forward references instead of the builtin function ``any``, which is not
    # a type.
    next: Optional["Node"] = None
    prev: Optional["Node"] = None


class TxEventCache:
    """Bounded cache of events keyed by transaction signature.

    Entries are kept in a doubly linked list with the most recently added key
    at the head. When ``max_tx`` entries are stored, adding a new key drops the
    entry at the tail.
    """

    def __init__(self, max_tx: int = 1024):
        self.size = 0
        self.max_tx = max_tx
        self.head = None
        self.tail = None
        self.cache_map: Dict[str, Node] = {}

    def add(self, key: str, events: List[WrappedEvent]) -> None:
        """Store ``events`` under ``key`` and make it the newest entry.

        Re-adding an existing key replaces its value. With ``max_tx == 0``
        nothing is stored (this used to raise ``AttributeError``).
        """
        existing_node = self.cache_map.get(key)
        if existing_node:
            self.detach(existing_node)
            self.size -= 1
        elif self.size == self.max_tx:
            if self.tail is None:
                # Only reachable with max_tx == 0: there is no room at all.
                return
            del self.cache_map[self.tail.key]
            self.detach(self.tail)
            self.size -= 1

        if not self.head:
            self.head = self.tail = Node(key, events)
        else:
            node = Node(key, events, next=self.head)
            self.head.prev = node
            self.head = node

        self.cache_map[key] = self.head
        self.size += 1

    def has(self, key: str) -> bool:
        """Return whether ``key`` is currently cached."""
        return key in self.cache_map

    def get(self, key: str) -> Optional[List[WrappedEvent]]:
        """Return the events stored under ``key``, or ``None`` if absent."""
        node = self.cache_map.get(key)
        return None if node is None else node.value

    def detach(self, node: Node) -> None:
        """Unlink ``node`` from the list; ``size`` and the map are untouched."""
        if node.prev is not None:
            node.prev.next = node.next
        else:
            self.head = node.next

        if node.next is not None:
            node.next.prev = node.prev
        else:
            self.tail = node.prev

    def clear(self) -> None:
        """Drop every entry."""
        self.head = None
        self.tail = None
        self.size = 0
        self.cache_map = {}
65 |
--------------------------------------------------------------------------------
/src/driftpy/events/websocket_log_provider.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import cast
3 |
4 | import websockets.exceptions
5 | from solana.rpc.async_api import AsyncClient
6 | from solana.rpc.commitment import Commitment
7 | from solana.rpc.websocket_api import SolanaWsClientProtocol, connect
8 | from solders.pubkey import Pubkey
9 | from solders.rpc.config import RpcTransactionLogsFilterMentions
10 |
11 | from driftpy.events.types import LogProvider, LogProviderCallback
12 | from driftpy.types import get_ws_url
13 |
14 |
class WebsocketLogProvider(LogProvider):
    """Log provider that streams transaction logs over a websocket."""

    def __init__(
        self, connection: AsyncClient, address: Pubkey, commitment: Commitment
    ):
        """Store the connection parameters; streaming starts in ``subscribe``.

        Args:
            connection: RPC client whose endpoint is used for the websocket.
            address: Account that a transaction must mention to be delivered.
            commitment: Commitment level for the logs subscription.
        """
        self.connection = connection
        self.address = address
        self.commitment = commitment
        self.task = None

    def subscribe(self, callback: LogProviderCallback):
        """Start the background websocket task unless one is already running."""
        if not self.is_subscribed():
            self.task = asyncio.create_task(self.subscribe_ws(callback))

    async def subscribe_ws(self, callback: LogProviderCallback):
        """Subscribe to logs and forward each successful transaction.

        ``callback`` receives ``(signature, slot, logs)``. Transactions with an
        error are skipped. The connection is re-established when it is closed.
        """
        endpoint = self.connection._provider.endpoint_uri
        # An http(s) endpoint is converted to its websocket URL; anything else
        # is used as given.
        if endpoint.startswith("http"):
            ws_endpoint = get_ws_url(endpoint)
        else:
            ws_endpoint = endpoint

        async for ws in connect(ws_endpoint):
            ws: SolanaWsClientProtocol
            try:
                await ws.logs_subscribe(
                    RpcTransactionLogsFilterMentions(self.address),
                    self.commitment,
                )

                # The first response carries the subscription id.
                first_resp = await ws.recv()
                subscription_id = cast(int, first_resp[0].result)

                async for msg in ws:
                    try:
                        slot = msg[0].result.context.slot
                        signature = msg[0].result.value.signature
                        logs = msg[0].result.value.logs

                        if msg[0].result.value.err:
                            continue

                        callback(signature, slot, logs)
                    except Exception as e:
                        print("Error processing event data", e)
                        break
                # This is a logs subscription, so it has to be cancelled with
                # the logs call rather than the account one.
                await ws.logs_unsubscribe(subscription_id)
            except websockets.exceptions.ConnectionClosed:
                print("Websocket closed, reconnecting...")
                continue

    def is_subscribed(self) -> bool:
        """Return whether a websocket task currently exists."""
        return self.task is not None

    def unsubscribe(self):
        """Cancel the websocket task, if any."""
        if self.is_subscribed():
            self.task.cancel()
            self.task = None
71 |
--------------------------------------------------------------------------------
/src/driftpy/idl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/src/driftpy/idl/__init__.py
--------------------------------------------------------------------------------
/src/driftpy/idl/pyth.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "0.1.0",
3 | "name": "pyth",
4 | "instructions": [
5 | {
6 | "name": "initialize",
7 | "accounts": [
8 | {
9 | "name": "price",
10 | "isMut": true,
11 | "isSigner": false
12 | }
13 | ],
14 | "args": [
15 | {
16 | "name": "price",
17 | "type": "i64"
18 | },
19 | {
20 | "name": "expo",
21 | "type": "i32"
22 | },
23 | {
24 | "name": "conf",
25 | "type": "u64"
26 | }
27 | ]
28 | },
29 | {
30 | "name": "setPrice",
31 | "accounts": [
32 | {
33 | "name": "price",
34 | "isMut": true,
35 | "isSigner": false
36 | }
37 | ],
38 | "args": [
39 | {
40 | "name": "price",
41 | "type": "i64"
42 | }
43 | ]
44 | },
45 | {
46 | "name": "setPriceInfo",
47 | "accounts": [
48 | {
49 | "name": "price",
50 | "isMut": true,
51 | "isSigner": false
52 | }
53 | ],
54 | "args": [
55 | {
56 | "name": "price",
57 | "type": "i64"
58 | },
59 | {
60 | "name": "conf",
61 | "type": "u64"
62 | },
63 | {
64 | "name": "slot",
65 | "type": "u64"
66 | }
67 | ]
68 | },
69 | {
70 | "name": "setTwap",
71 | "accounts": [
72 | {
73 | "name": "price",
74 | "isMut": true,
75 | "isSigner": false
76 | }
77 | ],
78 | "args": [
79 | {
80 | "name": "twap",
81 | "type": "i64"
82 | }
83 | ]
84 | }
85 | ],
86 | "types": [
87 | {
88 | "name": "PriceStatus",
89 | "type": {
90 | "kind": "enum",
91 | "variants": [
92 | {
93 | "name": "Unknown"
94 | },
95 | {
96 | "name": "Trading"
97 | },
98 | {
99 | "name": "Halted"
100 | },
101 | {
102 | "name": "Auction"
103 | }
104 | ]
105 | }
106 | },
107 | {
108 | "name": "CorpAction",
109 | "type": {
110 | "kind": "enum",
111 | "variants": [
112 | {
113 | "name": "NoCorpAct"
114 | }
115 | ]
116 | }
117 | },
118 | {
119 | "name": "PriceType",
120 | "type": {
121 | "kind": "enum",
122 | "variants": [
123 | {
124 | "name": "Unknown"
125 | },
126 | {
127 | "name": "Price"
128 | },
129 | {
130 | "name": "TWAP"
131 | },
132 | {
133 | "name": "Volatility"
134 | }
135 | ]
136 | }
137 | }
138 | ]
139 | }
--------------------------------------------------------------------------------
/src/driftpy/idl/sequence_enforcer.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "0.1.0",
3 | "name": "sequence_enforcer",
4 | "instructions": [
5 | {
6 | "name": "initialize",
7 | "accounts": [
8 | {
9 | "name": "sequenceAccount",
10 | "isMut": true,
11 | "isSigner": false
12 | },
13 | {
14 | "name": "authority",
15 | "isMut": false,
16 | "isSigner": true
17 | },
18 | {
19 | "name": "systemProgram",
20 | "isMut": false,
21 | "isSigner": false
22 | }
23 | ],
24 | "args": [
25 | {
26 | "name": "bump",
27 | "type": "u8"
28 | },
29 | {
30 | "name": "sym",
31 | "type": "string"
32 | }
33 | ]
34 | },
35 | {
36 | "name": "resetSequenceNumber",
37 | "accounts": [
38 | {
39 | "name": "sequenceAccount",
40 | "isMut": true,
41 | "isSigner": false
42 | },
43 | {
44 | "name": "authority",
45 | "isMut": false,
46 | "isSigner": true
47 | }
48 | ],
49 | "args": [
50 | {
51 | "name": "sequenceNum",
52 | "type": "u64"
53 | }
54 | ]
55 | },
56 | {
57 | "name": "checkAndSetSequenceNumber",
58 | "accounts": [
59 | {
60 | "name": "sequenceAccount",
61 | "isMut": true,
62 | "isSigner": false
63 | },
64 | {
65 | "name": "authority",
66 | "isMut": false,
67 | "isSigner": true
68 | }
69 | ],
70 | "args": [
71 | {
72 | "name": "sequenceNum",
73 | "type": "u64"
74 | }
75 | ]
76 | }
77 | ],
78 | "accounts": [
79 | {
80 | "name": "SequenceAccount",
81 | "type": {
82 | "kind": "struct",
83 | "fields": [
84 | {
85 | "name": "sequenceNum",
86 | "type": "u64"
87 | },
88 | {
89 | "name": "authority",
90 | "type": "publicKey"
91 | }
92 | ]
93 | }
94 | }
95 | ],
96 | "errors": [
97 | {
98 | "code": 6000,
99 | "name": "SequenceOutOfOrder",
100 | "msg": "Sequence out of order"
101 | }
102 | ]
103 | }
--------------------------------------------------------------------------------
/src/driftpy/idl/token_faucet.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "0.1.0",
3 | "name": "token_faucet",
4 | "instructions": [
5 | {
6 | "name": "initialize",
7 | "accounts": [
8 | {
9 | "name": "faucetConfig",
10 | "isMut": true,
11 | "isSigner": false
12 | },
13 | {
14 | "name": "admin",
15 | "isMut": true,
16 | "isSigner": true
17 | },
18 | {
19 | "name": "mintAccount",
20 | "isMut": true,
21 | "isSigner": false
22 | },
23 | {
24 | "name": "rent",
25 | "isMut": false,
26 | "isSigner": false
27 | },
28 | {
29 | "name": "systemProgram",
30 | "isMut": false,
31 | "isSigner": false
32 | },
33 | {
34 | "name": "tokenProgram",
35 | "isMut": false,
36 | "isSigner": false
37 | }
38 | ],
39 | "args": []
40 | },
41 | {
42 | "name": "mintToUser",
43 | "accounts": [
44 | {
45 | "name": "faucetConfig",
46 | "isMut": false,
47 | "isSigner": false
48 | },
49 | {
50 | "name": "mintAccount",
51 | "isMut": true,
52 | "isSigner": false
53 | },
54 | {
55 | "name": "userTokenAccount",
56 | "isMut": true,
57 | "isSigner": false
58 | },
59 | {
60 | "name": "mintAuthority",
61 | "isMut": false,
62 | "isSigner": false
63 | },
64 | {
65 | "name": "tokenProgram",
66 | "isMut": false,
67 | "isSigner": false
68 | }
69 | ],
70 | "args": [
71 | {
72 | "name": "amount",
73 | "type": "u64"
74 | }
75 | ]
76 | },
77 | {
78 | "name": "transferMintAuthority",
79 | "accounts": [
80 | {
81 | "name": "faucetConfig",
82 | "isMut": false,
83 | "isSigner": false
84 | },
85 | {
86 | "name": "admin",
87 | "isMut": true,
88 | "isSigner": true
89 | },
90 | {
91 | "name": "mintAccount",
92 | "isMut": true,
93 | "isSigner": false
94 | },
95 | {
96 | "name": "mintAuthority",
97 | "isMut": false,
98 | "isSigner": false
99 | },
100 | {
101 | "name": "tokenProgram",
102 | "isMut": false,
103 | "isSigner": false
104 | }
105 | ],
106 | "args": []
107 | }
108 | ],
109 | "accounts": [
110 | {
111 | "name": "FaucetConfig",
112 | "type": {
113 | "kind": "struct",
114 | "fields": [
115 | {
116 | "name": "admin",
117 | "type": "publicKey"
118 | },
119 | {
120 | "name": "mint",
121 | "type": "publicKey"
122 | },
123 | {
124 | "name": "mintAuthority",
125 | "type": "publicKey"
126 | },
127 | {
128 | "name": "mintAuthorityNonce",
129 | "type": "u8"
130 | }
131 | ]
132 | }
133 | }
134 | ],
135 | "errors": [
136 | {
137 | "code": 6000,
138 | "name": "InvalidMintAccountAuthority",
139 | "msg": "Program not mint authority"
140 | }
141 | ]
142 | }
--------------------------------------------------------------------------------
/src/driftpy/indicative_quotes/__init__.py:
--------------------------------------------------------------------------------
1 | from .indicative_quotes_sender import IndicativeQuotesSender, Quote
2 |
3 | __all__ = ["IndicativeQuotesSender", "Quote"]
4 |
--------------------------------------------------------------------------------
/src/driftpy/keypair.py:
--------------------------------------------------------------------------------
1 | from solders.keypair import Keypair
2 | import os
3 | import json
4 | import base58
5 |
6 |
def load_keypair(private_key):
    """Build a ``Keypair`` from a key string or from a file containing one.

    If ``private_key`` is an existing path, the stripped file content is used
    as the key text. The text is then read as a JSON array when it contains
    both ``[`` and ``]``, as comma-separated integers when it contains a comma,
    and otherwise as base58 after removing spaces.
    """
    # First try to treat the argument as a file path.
    if os.path.exists(private_key):
        with open(private_key, "r") as handle:
            private_key = handle.read().strip()

    if "[" in private_key and "]" in private_key:
        raw_key = bytes(json.loads(private_key))
    elif "," in private_key:
        raw_key = bytes(int(part) for part in private_key.split(","))
    else:
        raw_key = base58.b58decode(private_key.replace(" ", ""))

    return Keypair.from_bytes(raw_key)
23 |
--------------------------------------------------------------------------------
/src/driftpy/market_map/grpc_sub.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional, TypeVar
2 |
3 | from solana.rpc.commitment import Commitment
4 |
5 | from driftpy.accounts.grpc.program_account_subscriber import (
6 | GrpcProgramAccountSubscriber,
7 | )
8 | from driftpy.accounts.types import GrpcProgramAccountOptions, MarketUpdateCallback
9 | from driftpy.market_map.market_map import MarketMap
10 | from driftpy.memcmp import get_market_type_filter
11 | from driftpy.types import GrpcConfig, market_type_to_string
12 |
13 | T = TypeVar("T")
14 |
15 |
class GrpcSubscription:
    """Lazily created gRPC program-account subscription for a market map."""

    def __init__(
        self,
        grpc_config: GrpcConfig,
        market_map: MarketMap,
        commitment: Commitment,
        on_update: MarketUpdateCallback,
        decode: Optional[Callable[[bytes], T]] = None,
    ):
        """Remember the settings; the subscriber itself is built on subscribe."""
        self.subscriber = None
        self.grpc_config = grpc_config
        self.market_map = market_map
        self.commitment = commitment
        self.on_update = on_update
        self.decode = decode

    async def subscribe(self):
        """Create the subscriber on first use, then subscribe it."""
        if not self.subscriber:
            market_type = self.market_map.market_type
            # Restrict the subscription to accounts of this market type.
            options = GrpcProgramAccountOptions(
                (get_market_type_filter(market_type),), self.commitment
            )
            subscription_name = f"{market_type_to_string(market_type)}MarketMap"
            self.subscriber = GrpcProgramAccountSubscriber(
                subscription_name=subscription_name,
                program=self.market_map.program,
                grpc_config=self.grpc_config,
                on_update=self.on_update,
                options=options,
                decode=self.decode,
            )

        await self.subscriber.subscribe()

    async def unsubscribe(self):
        """Unsubscribe and discard the subscriber, if one exists."""
        if self.subscriber:
            await self.subscriber.unsubscribe()
            self.subscriber = None
52 |
--------------------------------------------------------------------------------
/src/driftpy/market_map/market_map_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 |
4 | from anchorpy import Program
5 | from solana.rpc.async_api import AsyncClient
6 | from solana.rpc.commitment import Commitment
7 |
8 | from driftpy.types import GrpcConfig, MarketType
9 |
10 |
@dataclass
class WebsocketConfig:
    """Optional websocket subscription settings; every field defaults to ``None``."""

    # Resubscribe timeout in milliseconds, if any.
    resub_timeout_ms: Optional[int] = None
    # Commitment level for the subscription, if any.
    commitment: Optional[Commitment] = None
15 |
16 |
@dataclass
class MarketMapConfig:
    """Configuration for a market map that uses a websocket subscription."""

    # Anchor program handle.
    program: Program
    market_type: MarketType  # perp market map or spot market map
    # Websocket-specific subscription settings.
    subscription_config: WebsocketConfig
    # RPC client.
    connection: AsyncClient
23 |
24 |
@dataclass
class GrpcMarketMapConfig:
    """Configuration for a market map that uses a gRPC subscription."""

    # Anchor program handle.
    program: Program
    market_type: MarketType  # perp market map or spot market map
    # gRPC-specific connection settings.
    grpc_config: GrpcConfig
    # RPC client.
    connection: AsyncClient
31 |
--------------------------------------------------------------------------------
/src/driftpy/market_map/websocket_sub.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional, TypeVar
2 |
3 | from solana.rpc.commitment import Commitment
4 |
5 | from driftpy.accounts.types import MarketUpdateCallback, WebsocketProgramAccountOptions
6 | from driftpy.accounts.ws.program_account_subscriber import (
7 | WebSocketProgramAccountSubscriber,
8 | )
9 | from driftpy.memcmp import get_market_type_filter
10 | from driftpy.types import market_type_to_string
11 |
# Generic result type of the optional ``decode`` callable accepted by WebsocketSubscription.
T = TypeVar("T")
13 |
14 |
class WebsocketSubscription:
    """Manage a lazily created websocket program-account subscriber for a market map.

    The underlying ``WebSocketProgramAccountSubscriber`` is only built on the
    first call to :meth:`subscribe` and is dropped again by :meth:`unsubscribe`.
    """

    def __init__(
        self,
        market_map,
        commitment: Commitment,
        on_update: MarketUpdateCallback,
        resub_timeout_ms: Optional[int] = None,
        decode: Optional[Callable[[bytes], T]] = None,
    ):
        """Store the configuration; no connection is opened here.

        Args:
            market_map: Object exposing ``market_type`` and ``program``.
            commitment: Commitment level placed in the subscriber options.
            on_update: Callback forwarded to the subscriber.
            resub_timeout_ms: Stored on the instance only.
                NOTE(review): this value is never passed to the subscriber
                built in ``subscribe`` - confirm whether that is intended.
            decode: Optional decoder forwarded to the subscriber.
        """
        self.subscriber = None
        self.decode = decode
        self.resub_timeout_ms = resub_timeout_ms
        self.on_update = on_update
        self.commitment = commitment
        self.market_map = market_map

    async def subscribe(self):
        """Create the subscriber if needed, then await its ``subscribe()``."""
        if not self.subscriber:
            market_type = self.market_map.market_type
            account_options = WebsocketProgramAccountOptions(
                (get_market_type_filter(market_type),),
                self.commitment,
                "base64",
            )
            subscription_name = f"{market_type_to_string(market_type)}MarketMap"
            self.subscriber = WebSocketProgramAccountSubscriber(
                subscription_name,
                self.market_map.program,
                account_options,
                self.on_update,
                self.decode,
            )

        await self.subscriber.subscribe()

    async def unsubscribe(self):
        """Unsubscribe and discard the subscriber; a no-op when none exists."""
        if self.subscriber:
            await self.subscriber.unsubscribe()
            self.subscriber = None
50 |
--------------------------------------------------------------------------------
/src/driftpy/math/conversion.py:
--------------------------------------------------------------------------------
1 | from driftpy.constants.numeric_constants import PRICE_PRECISION
2 |
3 |
def convert_to_number(big_number: int, precision: int = PRICE_PRECISION) -> float:
    """Convert a fixed-point integer into a number by dividing by ``precision``.

    The whole and fractional parts are computed separately so that only the
    remainder goes through float division. A falsy ``big_number`` (``0`` or
    ``None``) yields ``0``.
    """
    if big_number:
        whole_part, leftover = divmod(big_number, precision)
        return whole_part + leftover / precision
    return 0
8 |
--------------------------------------------------------------------------------
/src/driftpy/math/exchange_status.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | from driftpy.types import (
3 | PerpMarketAccount,
4 | SpotMarketAccount,
5 | StateAccount,
6 | is_one_of_variant,
7 | is_variant,
8 | )
9 |
10 |
class ExchangeStatusValues:
    """Single-bit flag values that are AND-ed against ``StateAccount.exchange_status``.

    NOTE(review): confirm these bit positions against the on-chain
    ExchangeStatus definition; they are reproduced here unchanged.
    """

    Active = 1 << 0
    DepositPaused = 1 << 1
    WithdrawPaused = 1 << 2
    AmmPaused = 1 << 3
    FillPaused = 1 << 4
    LiqPaused = 1 << 5
    FundingPaused = 1 << 6
    SettlePnlPaused = 1 << 7
20 |
21 |
def exchange_paused(state: StateAccount) -> bool:
    """Return True unless ``state.exchange_status`` is the "Active" variant.

    NOTE(review): this treats ``exchange_status`` as an enum variant, while
    ``fill_paused`` and ``amm_paused`` in this module treat it as an integer
    bitmask - confirm which representation ``StateAccount`` actually uses.
    """
    is_active = is_variant(state.exchange_status, "Active")
    return not is_active
24 |
25 |
def fill_paused(
    state: StateAccount, market: Union[PerpMarketAccount, SpotMarketAccount]
) -> bool:
    """Report whether fills are paused exchange-wide or for ``market``.

    True when the FillPaused bit is set in ``state.exchange_status``; otherwise
    the result of checking ``market.status`` for "Paused" or "FillPaused".
    """
    fill_flag = ExchangeStatusValues.FillPaused
    if (state.exchange_status & fill_flag) == fill_flag:
        return True
    return is_one_of_variant(market.status, ["Paused", "FillPaused"])
34 |
35 |
def amm_paused(
    state: StateAccount, market: Union[PerpMarketAccount, SpotMarketAccount]
) -> bool:
    """Report whether the AMM is paused exchange-wide or for ``market``.

    True when the AmmPaused bit is set in ``state.exchange_status``; otherwise
    the result of checking ``market.status`` for "Paused" or "AmmPaused".
    """
    amm_flag = ExchangeStatusValues.AmmPaused
    if (state.exchange_status & amm_flag) == amm_flag:
        return True
    return is_one_of_variant(market.status, ["Paused", "AmmPaused"])
44 |
--------------------------------------------------------------------------------
/src/driftpy/math/fuel.py:
--------------------------------------------------------------------------------
1 | from driftpy.types import SpotMarketAccount, PerpMarketAccount
2 | from driftpy.constants.numeric_constants import QUOTE_PRECISION, FUEL_WINDOW
3 |
4 |
def calculate_insurance_fuel_bonus(
    spot_market: SpotMarketAccount, token_stake_amount: int, fuel_bonus_numerator: int
) -> int:
    """Compute the scaled insurance-fund fuel bonus for a stake.

    Multiplies ``|token_stake_amount|`` by ``fuel_bonus_numerator`` and the
    market's ``fuel_boost_insurance``, then floor-divides by ``FUEL_WINDOW``
    and by ``QUOTE_PRECISION // 10``.
    """
    scale_divisor = QUOTE_PRECISION // 10
    weighted_stake = abs(token_stake_amount) * fuel_bonus_numerator
    boosted_stake = weighted_stake * spot_market.fuel_boost_insurance
    per_window = boosted_stake // FUEL_WINDOW
    return per_window // scale_divisor
15 |
16 |
def calculate_spot_fuel_bonus(
    spot_market: SpotMarketAccount, signed_token_value: int, fuel_bonus_numerator: int
) -> int:
    """Compute the scaled fuel bonus for a spot position.

    Values whose magnitude is at most ``QUOTE_PRECISION`` are treated as dust
    and earn 0. Positive values use the market's ``fuel_boost_deposits``;
    all other values use ``fuel_boost_borrows``.
    """
    magnitude = abs(signed_token_value)

    # dust
    if magnitude <= QUOTE_PRECISION:
        return 0

    if signed_token_value > 0:
        boost = spot_market.fuel_boost_deposits
    else:
        boost = spot_market.fuel_boost_borrows

    fuel_per_window = (magnitude * fuel_bonus_numerator * boost) // FUEL_WINDOW
    return fuel_per_window // (QUOTE_PRECISION // 10)
39 |
40 |
def calculate_perp_fuel_bonus(
    perp_market: PerpMarketAccount, base_asset_value: int, fuel_bonus_numerator: int
) -> int:
    """Compute the scaled fuel bonus for a perp position.

    Values whose magnitude is at most ``QUOTE_PRECISION`` are treated as dust
    and earn 0; otherwise the magnitude is weighted by ``fuel_bonus_numerator``
    and the market's ``fuel_boost_position`` and scaled down.
    """
    magnitude = abs(base_asset_value)

    # dust
    if magnitude <= QUOTE_PRECISION:
        return 0

    boosted = magnitude * fuel_bonus_numerator * perp_market.fuel_boost_position
    fuel_per_window = boosted // FUEL_WINDOW
    return fuel_per_window // (QUOTE_PRECISION // 10)
57 |
--------------------------------------------------------------------------------
/src/driftpy/math/market.py:
--------------------------------------------------------------------------------
1 | from driftpy.types import OraclePriceData, PerpMarketAccount, PositionDirection
2 |
3 |
def calculate_bid_price(
    market: PerpMarketAccount, oracle_price_data: OraclePriceData
) -> int:
    """Return the AMM price computed from the spread reserves for the Short direction."""
    # Imported inside the function, as in the rest of this module.
    from driftpy.math.amm import calculate_price
    from driftpy.math.amm import calculate_updated_amm_spread_reserves

    short_side = PositionDirection.Short()
    base_reserve, quote_reserve, peg, _unused = calculate_updated_amm_spread_reserves(
        market.amm, short_side, oracle_price_data
    )
    return calculate_price(base_reserve, quote_reserve, peg)
19 |
20 |
def calculate_ask_price(
    market: PerpMarketAccount, oracle_price_data: OraclePriceData
) -> int:
    """Return the AMM price computed from the spread reserves for the Long direction."""
    # Imported inside the function, as in the rest of this module.
    from driftpy.math.amm import calculate_price
    from driftpy.math.amm import calculate_updated_amm_spread_reserves

    long_side = PositionDirection.Long()
    base_reserve, quote_reserve, peg, _unused = calculate_updated_amm_spread_reserves(
        market.amm, long_side, oracle_price_data
    )
    return calculate_price(base_reserve, quote_reserve, peg)
36 |
--------------------------------------------------------------------------------
/src/driftpy/math/spot_market.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | from driftpy.math.utils import div_ceil
4 | from driftpy.types import (
5 | OraclePriceData,
6 | SpotBalanceType,
7 | SpotMarketAccount,
8 | is_variant,
9 | )
10 |
11 |
def get_signed_token_amount(amount, balance_type):
    """Return ``amount`` unchanged for a "Deposit" balance, else ``-abs(amount)``."""
    if is_variant(balance_type, "Deposit"):
        return amount
    return -abs(amount)
14 |
15 |
def get_token_amount(
    balance: int, spot_market: SpotMarketAccount, balance_type: SpotBalanceType
) -> int:
    """Convert a scaled balance into a token amount using cumulative interest.

    Args:
        balance: Scaled balance (expected to be non-negative).
        spot_market: Market supplying ``decimals`` and the cumulative
            deposit/borrow interest.
        balance_type: "Deposit" rounds down; any other variant rounds up.

    Returns:
        ``balance * cumulative_interest`` divided by ``10 ** (19 - decimals)``.
    """
    precision_decrease = 10 ** (19 - spot_market.decimals)

    if is_variant(balance_type, "Deposit"):
        # Stay in integer arithmetic: true division converts the product to a
        # float, which silently loses precision once it exceeds 2**53.
        return (balance * spot_market.cumulative_deposit_interest) // precision_decrease

    return div_ceil(
        balance * spot_market.cumulative_borrow_interest, precision_decrease
    )
29 |
30 |
def get_token_value(
    amount, spot_decimals, oracle_price_data: Union[OraclePriceData, int]
):
    """Return ``amount * price // 10**spot_decimals``.

    ``oracle_price_data`` may be an ``OraclePriceData`` (its ``price`` is used)
    or the price itself.
    """
    precision_decrease = 10**spot_decimals
    price = (
        oracle_price_data.price
        if isinstance(oracle_price_data, OraclePriceData)
        else oracle_price_data
    )
    return amount * price // precision_decrease
39 |
40 |
def cast_to_spot_precision(
    amount: Union[float, int], spot_market: SpotMarketAccount
) -> int:
    """Scale ``amount`` by ``10 ** spot_market.decimals`` and truncate to an int."""
    scaled_amount = amount * (10**spot_market.decimals)
    return int(scaled_amount)
46 |
--------------------------------------------------------------------------------
/src/driftpy/math/user_status.py:
--------------------------------------------------------------------------------
1 | from driftpy.types import UserAccount, UserStatus
2 |
3 |
def is_user_protected_maker(user_account: UserAccount) -> bool:
    """Return True when the PROTECTED_MAKER bit is set in ``user_account.status``."""
    protected_bits = user_account.status & UserStatus.PROTECTED_MAKER
    return protected_bits > 0
6 |
--------------------------------------------------------------------------------
/src/driftpy/math/utils.py:
--------------------------------------------------------------------------------
def clamp_num(x: int, min_clamp: int, max_clamp: int) -> int:
    """Clamp ``x`` to ``max_clamp`` from above, then to ``min_clamp`` from below.

    When ``min_clamp > max_clamp`` the lower bound wins.
    """
    upper_bounded = max_clamp if max_clamp < x else x
    return upper_bounded if upper_bounded > min_clamp else min_clamp
3 |
4 |
def div_ceil(a: int, b: int) -> int:
    """Divide ``a`` by ``b``, rounding up when there is a positive remainder.

    A zero divisor does not raise: ``a`` is returned unchanged.
    """
    if b == 0:
        return a

    whole, leftover = divmod(a, b)
    return whole + 1 if leftover > 0 else whole
16 |
17 |
def sig_num(x: int) -> int:
    """Return -1 for negative ``x`` and 1 otherwise (zero counts as positive)."""
    if x < 0:
        return -1
    return 1
20 |
21 |
def time_remaining_until_update(now: int, last_update_ts: int, update_period: int) -> int:
    """Return how long to wait, from ``now``, until the next update is due.

    The next update is aligned to a multiple of ``update_period``. If the last
    update landed more than a third of a period past its boundary, one extra
    period is added to the wait. The result is never negative.

    Raises:
        ValueError: If ``update_period`` is not positive.
    """
    if update_period <= 0:
        raise ValueError("update_period must be positive")

    elapsed = now - last_update_ts

    if update_period == 1:
        return max(0, 1 - elapsed)

    # How far past a period boundary the last update happened.
    offset_in_period = last_update_ts % update_period
    tolerated_offset = update_period // 3

    if offset_in_period > tolerated_offset:
        wait = 2 * update_period - offset_in_period
    else:
        wait = update_period - offset_in_period

    return max(0, wait - elapsed)
40 |
--------------------------------------------------------------------------------
/src/driftpy/memcmp.py:
--------------------------------------------------------------------------------
1 | import base58
2 | from anchorpy.coder.accounts import _account_discriminator
3 | from solana.rpc.types import MemcmpOpts
4 |
5 | from driftpy.name import encode_name
6 | from driftpy.types import MarketType, is_variant
7 |
8 |
def get_user_filter() -> MemcmpOpts:
    """Build a memcmp filter matching the "User" account discriminator at offset 0."""
    discriminator = _account_discriminator("User")
    encoded = base58.b58encode(discriminator).decode()
    return MemcmpOpts(0, encoded)
11 |
12 |
def get_non_idle_user_filter() -> MemcmpOpts:
    """Build a memcmp filter matching the single byte 0 at offset 4350."""
    encoded = base58.b58encode(bytes([0])).decode()
    return MemcmpOpts(4350, encoded)
15 |
16 |
def get_user_with_auction_filter() -> MemcmpOpts:
    """Build a memcmp filter matching the single byte 1 at offset 4354."""
    encoded = base58.b58encode(bytes([1])).decode()
    return MemcmpOpts(4354, encoded)
19 |
20 |
def get_user_with_order_filter() -> MemcmpOpts:
    """Build a memcmp filter matching the single byte 1 at offset 4352."""
    encoded = base58.b58encode(bytes([1])).decode()
    return MemcmpOpts(4352, encoded)
23 |
24 |
def get_user_without_order_filter() -> MemcmpOpts:
    """Build a memcmp filter matching the single byte 0 at offset 4352."""
    encoded = base58.b58encode(bytes([0])).decode()
    return MemcmpOpts(4352, encoded)
27 |
28 |
def get_user_that_has_been_lp_filter() -> MemcmpOpts:
    """Build a memcmp filter matching the single byte 99 at offset 4267."""
    encoded = base58.b58encode(bytes([99])).decode()
    return MemcmpOpts(4267, encoded)
31 |
32 |
def get_user_with_name_filter(name: str) -> MemcmpOpts:
    """Build a memcmp filter matching the encoded form of ``name`` at offset 72."""
    padded_name = bytes(encode_name(name))
    encoded = base58.b58encode(padded_name).decode()
    return MemcmpOpts(72, encoded)
36 |
37 |
def get_users_with_pool_id_filter(pool_id: int) -> MemcmpOpts:
    """Build a memcmp filter matching ``pool_id`` as a single byte at offset 4356.

    ``pool_id`` must fit in one byte (0-255); ``bytes([pool_id])`` raises
    ``ValueError`` otherwise.
    """
    pool_byte = bytes([pool_id])
    encoded = base58.b58encode(pool_byte).decode()
    return MemcmpOpts(4356, encoded)
40 |
41 |
def get_market_type_filter(market_type: MarketType) -> MemcmpOpts:
    """Build a memcmp filter on the market account discriminator at offset 0.

    Uses the "PerpMarket" discriminator for the "Perp" variant and the
    "SpotMarket" discriminator for anything else.
    """
    account_name = "PerpMarket" if is_variant(market_type, "Perp") else "SpotMarket"
    encoded = base58.b58encode(_account_discriminator(account_name)).decode()
    return MemcmpOpts(0, encoded)
51 |
52 |
def get_user_stats_filter() -> MemcmpOpts:
    """Build a memcmp filter matching the "UserStats" account discriminator at offset 0."""
    discriminator = _account_discriminator("UserStats")
    encoded = base58.b58encode(discriminator).decode()
    return MemcmpOpts(0, encoded)
55 |
56 |
def get_user_stats_is_referred_filter() -> MemcmpOpts:
    """Build a memcmp filter matching the single byte 2 at offset 188."""
    referred_byte = bytes([2])
    encoded = base58.b58encode(referred_byte).decode()
    return MemcmpOpts(188, encoded)
60 |
61 |
def get_user_stats_is_referred_or_referrer_filter() -> MemcmpOpts:
    """Build a memcmp filter matching the single byte 3 at offset 188."""
    referral_byte = bytes([3])
    encoded = base58.b58encode(referral_byte).decode()
    return MemcmpOpts(188, encoded)
65 |
66 |
def get_signed_msg_user_orders_filter() -> MemcmpOpts:
    """Build a memcmp filter matching the "SignedMsgUserOrders" discriminator at offset 0."""
    discriminator = _account_discriminator("SignedMsgUserOrders")
    encoded = base58.b58encode(discriminator).decode()
    return MemcmpOpts(0, encoded)
71 |
--------------------------------------------------------------------------------
/src/driftpy/name.py:
--------------------------------------------------------------------------------
1 | from struct import pack_into
2 |
# Fixed width, in bytes, of an encoded name.
MAX_LENGTH = 32


def encode_name(name: str) -> list[int]:
    """Encode ``name`` as a fixed-width, space-padded list of UTF-8 byte values.

    The limit is enforced on the encoded byte length rather than the character
    count, so multi-byte characters are never cut in half or silently dropped.

    Args:
        name: The name to encode.

    Returns:
        A list of exactly ``MAX_LENGTH`` integers in ``range(256)``: the UTF-8
        bytes of ``name`` followed by ASCII spaces (32) as padding.

    Raises:
        ValueError: If the UTF-8 encoding of ``name`` exceeds ``MAX_LENGTH`` bytes.
    """
    encoded = name.encode("utf-8")
    if len(encoded) > MAX_LENGTH:
        raise ValueError("name too long")

    return list(encoded.ljust(MAX_LENGTH, b" "))
23 |
--------------------------------------------------------------------------------
/src/driftpy/oracles/oracle_id.py:
--------------------------------------------------------------------------------
1 | from solders.pubkey import Pubkey
2 |
3 | from driftpy.types import OracleSource, OracleSourceNum
4 |
5 |
def get_oracle_source_num(source: OracleSource) -> int:
    """Map an ``OracleSource`` variant to its ``OracleSourceNum`` value.

    Matching is done by substring on ``str(source)``, so candidates are tried
    from most specific to least specific. A name that is a prefix of another
    (e.g. "Pyth1M" vs "Pyth1MPull") must come after the longer name, otherwise
    the longer variant is misclassified as the shorter one.

    Raises:
        ValueError: If no known variant name occurs in ``str(source)``.
    """
    source_str = str(source)

    # Order matters: every name appears before any name that is a substring of it.
    ordered_matches = (
        ("Pyth1MPull", OracleSourceNum.PYTH_1M_PULL),
        ("Pyth1KPull", OracleSourceNum.PYTH_1K_PULL),
        ("Pyth1M", OracleSourceNum.PYTH_1M),
        ("Pyth1K", OracleSourceNum.PYTH_1K),
        ("PythStableCoinPull", OracleSourceNum.PYTH_STABLE_COIN_PULL),
        ("PythStableCoin", OracleSourceNum.PYTH_STABLE_COIN),
        ("PythPull", OracleSourceNum.PYTH_PULL),
        ("PythLazerStableCoin", OracleSourceNum.PYTH_LAZER_STABLE_COIN),
        ("PythLazer1K", OracleSourceNum.PYTH_LAZER_1K),
        ("PythLazer1M", OracleSourceNum.PYTH_LAZER_1M),
        ("PythLazer", OracleSourceNum.PYTH_LAZER),
        ("Pyth", OracleSourceNum.PYTH),
        ("SwitchboardOnDemand", OracleSourceNum.SWITCHBOARD_ON_DEMAND),
        ("Switchboard", OracleSourceNum.SWITCHBOARD),
        ("QuoteAsset", OracleSourceNum.QUOTE_ASSET),
        ("Prelaunch", OracleSourceNum.PRELAUNCH),
    )

    for variant_name, source_num in ordered_matches:
        if variant_name in source_str:
            return source_num

    raise ValueError("Invalid oracle source")
43 |
44 |
def get_oracle_id(public_key: Pubkey, source: OracleSource) -> str:
    """
    Returns the oracle id for a given oracle and source, formatted as
    "<public key>-<oracle source number>"
    """
    key_text = str(public_key)
    source_num = get_oracle_source_num(source)
    return f"{key_text}-{source_num}"
50 |
--------------------------------------------------------------------------------
/src/driftpy/oracles/strict_oracle_price.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 |
class StrictOraclePrice:
    """Pair a current oracle price with an optional TWAP for conservative pricing."""

    def __init__(self, current: int, twap: Optional[int] = None):
        """Store the prices.

        Args:
            current: The current oracle price.
            twap: Optional time-weighted average price; ``None`` means absent.
        """
        self.current = current
        self.twap = twap

    def max(self) -> int:
        """Return the larger of ``current`` and ``twap``, or ``current`` if no TWAP."""
        # Compare against None explicitly: a TWAP of 0 is a real value, not "missing".
        if self.twap is not None:
            return max(self.twap, self.current)
        return self.current

    def min(self) -> int:
        """Return the smaller of ``current`` and ``twap``, or ``current`` if no TWAP."""
        if self.twap is not None:
            return min(self.twap, self.current)
        return self.current
20 |
--------------------------------------------------------------------------------
/src/driftpy/priority_fees/priority_fee_subscriber.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from dataclasses import dataclass
3 |
4 | import jsonrpcclient
5 | from solana.rpc.async_api import AsyncClient
6 |
7 |
@dataclass
class PriorityFeeConfig:
    """Settings consumed by PriorityFeeSubscriber."""

    # RPC client whose provider session/endpoint is used for requests.
    connection: AsyncClient
    # Delay between polls, in seconds.
    frequency_secs: int
    # Addresses passed to the getRecentPrioritizationFees request.
    addresses: list[str]
    # Number of most recent slots considered when computing statistics.
    slots_to_check: int = 10
14 |
15 |
class PriorityFeeSubscriber:
    """Periodically poll ``getRecentPrioritizationFees`` and cache fee statistics."""

    def __init__(self, config: PriorityFeeConfig):
        """Copy settings from ``config`` and zero the cached statistics."""
        self.connection = config.connection
        # NOTE: despite the attribute name, this value is in seconds; it is
        # passed straight to asyncio.sleep(). The name is kept for compatibility.
        self.frequency_ms = config.frequency_secs
        self.addresses = config.addresses
        self.slots_to_check = config.slots_to_check

        self.latest_priority_fee = 0
        self.avg_priority_fee = 0
        self.max_priority_fee = 0
        self.last_slot_seen = 0
        self.subscribed = False

        # Strong references to background tasks. The event loop only keeps weak
        # references, so an unreferenced task may be garbage-collected mid-run.
        self._poll_task = None
        self._load_tasks = set()

    async def subscribe(self):
        """Start the background polling loop; a no-op if already subscribed."""
        if self.subscribed:
            return

        self.subscribed = True

        self._poll_task = asyncio.create_task(self.poll())

    async def poll(self):
        """Spawn a ``load()`` task every ``frequency_ms`` seconds while subscribed."""
        while self.subscribed:
            load_task = asyncio.create_task(self.load())
            self._load_tasks.add(load_task)
            load_task.add_done_callback(self._load_tasks.discard)
            await asyncio.sleep(self.frequency_ms)

    async def load(self):
        """Fetch recent prioritization fees and refresh the cached statistics.

        Raises:
            ValueError: If the RPC response is an error or not an OK response.
        """
        rpc_request = jsonrpcclient.request(
            "getRecentPrioritizationFees", [self.addresses]
        )

        post = self.connection._provider.session.post(
            self.connection._provider.endpoint_uri,
            json=rpc_request,
            headers={"content-encoding": "gzip"},
        )

        resp = await asyncio.wait_for(post, timeout=20)

        parsed_resp = jsonrpcclient.parse(resp.json())

        if isinstance(parsed_resp, jsonrpcclient.Error):
            raise ValueError(f"Error fetching priority fees: {parsed_resp.message}")

        if not isinstance(parsed_resp, jsonrpcclient.Ok):
            raise ValueError(f"Error fetching priority fees - not ok: {parsed_resp}")

        result = parsed_resp.result

        # Newest slots first, limited to the configured window.
        desc_results = sorted(result, key=lambda x: x["slot"], reverse=True)[
            : self.slots_to_check
        ]

        if not desc_results:
            return

        fees = [item["prioritizationFee"] for item in desc_results]

        self.latest_priority_fee = fees[0]
        self.last_slot_seen = desc_results[0]["slot"]
        self.avg_priority_fee = sum(fees) / len(fees)
        self.max_priority_fee = max(fees)

    async def unsubscribe(self):
        """Stop polling; a no-op if not subscribed."""
        if not self.subscribed:
            return

        self.subscribed = False

        # Cancel the loop so a quick re-subscribe does not leave two loops
        # running once ``subscribed`` flips back to True.
        if self._poll_task is not None:
            self._poll_task.cancel()
            self._poll_task = None
82 |
--------------------------------------------------------------------------------
/src/driftpy/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/src/driftpy/py.typed
--------------------------------------------------------------------------------
/src/driftpy/slot/slot_subscriber.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | from events import Events as EventEmitter
4 | from solana.rpc.websocket_api import SolanaWsClientProtocol, connect
5 |
6 | from driftpy.dlob.client_types import SlotSource
7 | from driftpy.drift_client import DriftClient
8 | from driftpy.types import get_ws_url
9 |
10 |
class SlotSubscriber(SlotSource):
    """Track the current slot through a websocket ``slotSubscribe`` stream."""

    def __init__(self, drift_client: DriftClient):
        """Take the connection and program from ``drift_client``; nothing is opened."""
        self.current_slot = 0
        self.subscription_id = None
        self.connection = drift_client.connection
        self.program = drift_client.program
        self.ws = None
        self.subscribed = False
        # Background task running subscribe_ws(); created by subscribe().
        self.task = None
        # NOTE(review): ("on_slot_change") is a plain string, not a 1-tuple.
        # Left unchanged because the following ``.on(...)`` call may rely on it.
        self.event_emitter = EventEmitter(("on_slot_change"))
        self.event_emitter.on("on_slot_change")

    async def on_slot_change(self, slot_info: int):
        """Record the new slot and emit ``on_slot_change`` with it."""
        self.current_slot = slot_info
        self.event_emitter.on_slot_change(slot_info)

    async def subscribe(self):
        """Start the websocket loop in a background task and return that task.

        Returns ``None`` without starting anything when already subscribed.
        """
        if self.subscribed:
            return
        self.task = asyncio.create_task(self.subscribe_ws())
        return self.task

    async def subscribe_ws(self):
        """Fetch the current slot, then stream slot updates, reconnecting on error."""
        if self.subscription_id is not None:
            return

        current_slot_response = await self.connection.get_slot()
        self.current_slot = int(current_slot_response.value)

        endpoint = self.program.provider.connection._provider.endpoint_uri
        ws_endpoint = get_ws_url(endpoint)
        while True:
            try:
                async with connect(ws_endpoint) as ws:
                    self.subscribed = True
                    self.ws = ws
                    ws: SolanaWsClientProtocol
                    self.subscription_id = await ws.slot_subscribe()

                    # Consume one message before entering the update loop.
                    await ws.recv()

                    async for msg in ws:
                        await self.on_slot_change(msg[0].result.slot)

            except Exception as e:
                print(f"Error in SlotSubscriber: {e}")
                if self.ws:
                    await self.ws.close()
                    self.ws = None
                await asyncio.sleep(5)  # wait five seconds before we retry

    def get_slot(self) -> int:
        """Return the most recently observed slot."""
        return self.current_slot

    async def unsubscribe(self):
        """Stop the background loop, close the websocket and reset state.

        Cancelling the task is required: closing the socket alone would let
        the ``while True`` loop in ``subscribe_ws`` reconnect.
        """
        task, self.task = self.task, None
        if (
            task is not None
            and not task.done()
            and task is not asyncio.current_task()
        ):
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        if self.ws:
            await self.ws.close()
            self.ws = None

        # Reset so that a later subscribe() is able to start a fresh stream.
        self.subscription_id = None
        self.subscribed = False
69 |
--------------------------------------------------------------------------------
/src/driftpy/swift/create_verify_ix.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 |
3 | from construct import Int8ul, Int16ul, Struct
4 | from solana.constants import ED25519_PROGRAM_ID
5 | from solders.instruction import Instruction
6 |
# Byte lengths used to compute offsets in the functions below.
ED25519_INSTRUCTION_LEN = 16
SIGNATURE_LEN = 64
PUBKEY_LEN = 32
# Length of the prefix that precedes the signature.
MAGIC_LEN = 4
# Length of the little-endian message-size field that precedes the message.
MESSAGE_SIZE_LEN = 2
12 |
13 |
def trim_feed_id(feed_id: str) -> str:
    """Return ``feed_id`` with a leading "0x" removed, if present."""
    return feed_id[2:] if feed_id.startswith("0x") else feed_id
18 |
19 |
def get_feed_id_uint8_array(feed_id: str) -> bytes:
    """Decode a hex feed id (with or without a leading "0x") into bytes."""
    hex_digits = feed_id[2:] if feed_id.startswith("0x") else feed_id
    return bytes.fromhex(hex_digits)
23 |
24 |
def get_ed25519_args_from_hex(
    hex_str: str, custom_instruction_index: Optional[int] = None
) -> Dict[str, bytes]:
    """Split a hex-encoded signed payload into its ed25519 verification parts.

    The decoded buffer is read as: ``MAGIC_LEN`` prefix bytes, a
    ``SIGNATURE_LEN`` signature, a ``PUBKEY_LEN`` public key, a 2-byte
    little-endian message size, then the message.

    Args:
        hex_str: Hex string, with or without a leading "0x".
        custom_instruction_index: Returned unchanged as "instruction_index".

    Returns:
        Dict with "public_key", "signature", "message" (bytes) and
        "instruction_index" (the given index, possibly ``None``).

    Raises:
        ValueError: If the buffer is too short to hold the public key, the
            signature or the message-size field.
    """
    cleaned_hex = hex_str[2:] if hex_str.startswith("0x") else hex_str
    buffer = bytes.fromhex(cleaned_hex)

    signature_offset = MAGIC_LEN
    public_key_offset = signature_offset + SIGNATURE_LEN
    message_data_size_offset = public_key_offset + PUBKEY_LEN
    message_data_offset = message_data_size_offset + MESSAGE_SIZE_LEN

    signature = buffer[signature_offset : signature_offset + SIGNATURE_LEN]
    public_key = buffer[public_key_offset : public_key_offset + PUBKEY_LEN]

    # Validate before indexing into the buffer, so a truncated payload raises
    # the intended ValueError rather than an IndexError from the size read.
    if len(public_key) != PUBKEY_LEN:
        raise ValueError("Invalid public key length")

    if len(signature) != SIGNATURE_LEN:
        raise ValueError("Invalid signature length")

    if len(buffer) < message_data_offset:
        raise ValueError("Missing message size field")

    message_size = int.from_bytes(
        buffer[message_data_size_offset:message_data_offset], "little"
    )
    message = buffer[message_data_offset : message_data_offset + message_size]

    return {
        "public_key": public_key,
        "signature": signature,
        "message": message,
        "instruction_index": custom_instruction_index,
    }
55 |
56 |
def read_uint16_le(data: bytes, offset: int) -> int:
    """Read an unsigned 16-bit little-endian integer from ``data`` at ``offset``."""
    low_byte, high_byte = data[offset], data[offset + 1]
    return (high_byte << 8) | low_byte
59 |
60 |
# Binary layout built by create_minimal_ed25519_verify_ix: two single-byte
# fields followed by seven little-endian 16-bit fields (16 bytes in total).
ED25519_INSTRUCTION_LAYOUT = Struct(
    "num_signatures" / Int8ul,
    "padding" / Int8ul,
    "signature_offset" / Int16ul,
    "signature_instruction_index" / Int16ul,
    "public_key_offset" / Int16ul,
    "public_key_instruction_index" / Int16ul,
    "message_data_offset" / Int16ul,
    "message_data_size" / Int16ul,
    "message_instruction_index" / Int16ul,
)
72 |
73 |
def create_minimal_ed25519_verify_ix(
    custom_instruction_index: int,
    message_offset: int,
    custom_instruction_data: bytes,
    magic_len: Optional[int] = None,
) -> Instruction:
    """Build an ed25519 program instruction that points at another instruction's data.

    The instruction carries only the offsets header; signature, public key and
    message are all referenced by ``custom_instruction_index``.

    Args:
        custom_instruction_index: Instruction index written into all three
            ``*_instruction_index`` fields.
        message_offset: Offset added to every computed position.
        custom_instruction_data: Bytes from which the 2-byte little-endian
            message size is read.
        magic_len: Prefix length before the signature; ``MAGIC_LEN`` when None.

    Returns:
        An ``Instruction`` for ``ED25519_PROGRAM_ID`` with no accounts.
    """
    prefix_len = MAGIC_LEN if magic_len is None else magic_len

    sig_start = message_offset + prefix_len
    key_start = sig_start + SIGNATURE_LEN
    size_field_start = key_start + PUBKEY_LEN
    payload_start = size_field_start + MESSAGE_SIZE_LEN

    # The size field is read at its position relative to the start of
    # custom_instruction_data, hence the subtraction of message_offset.
    payload_len = read_uint16_le(
        custom_instruction_data, size_field_start - message_offset
    )

    header_fields = {
        "num_signatures": 1,
        "padding": 0,
        "signature_offset": sig_start,
        "signature_instruction_index": custom_instruction_index,
        "public_key_offset": key_start,
        "public_key_instruction_index": custom_instruction_index,
        "message_data_offset": payload_start,
        "message_data_size": payload_len,
        "message_instruction_index": custom_instruction_index,
    }
    header_bytes = ED25519_INSTRUCTION_LAYOUT.build(header_fields)

    return Instruction(
        accounts=[],
        program_id=ED25519_PROGRAM_ID,
        data=header_bytes,
    )
108 |
--------------------------------------------------------------------------------
/src/driftpy/swift/util.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import hashlib
3 | import random
4 | import string
5 |
6 |
def digest_signature(signature: bytes) -> str:
    """
    Hash a signature with SHA-256 and return the digest encoded as base64 text.

    Args:
        signature: The signature bytes to hash

    Returns:
        Base64-encoded SHA-256 hash of the signature
    """
    raw_digest = hashlib.sha256(signature).digest()
    encoded_digest = base64.b64encode(raw_digest)
    return encoded_digest.decode("utf-8")
19 |
20 |
def generate_signed_msg_uuid(length=8) -> bytes:
    """
    Produce a random nanoid-style identifier of the given length, as UTF-8 bytes.

    Args:
        length: Length of the random string to generate (default: 8)

    Returns:
        Bytes representation of the generated random string
    """
    alphabet = string.ascii_letters + string.digits + "-_"  # nanoid alphabet
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return "".join(picked).encode("utf-8")
34 |
--------------------------------------------------------------------------------
/src/driftpy/tx/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/src/driftpy/tx/__init__.py
--------------------------------------------------------------------------------
/src/driftpy/tx/fast_tx_sender.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | from solders.hash import Hash
4 |
5 | from solana.rpc.async_api import AsyncClient
6 | from solana.rpc.types import TxOpts
7 | from solana.rpc.commitment import Commitment, Confirmed
8 |
9 | from driftpy.tx.standard_tx_sender import StandardTxSender
10 |
11 |
class FastTxSender(StandardTxSender):
    """
    The FastTxSender will refresh the latest blockhash in the background to save an RPC when building transactions.
    """

    def __init__(
        self,
        connection: AsyncClient,
        opts: TxOpts,
        blockhash_refresh_interval_secs: int,
        blockhash_commitment: Commitment = Confirmed,
    ):
        """Initialise the sender; the refresh loop is not started here.

        Args:
            connection: RPC client, forwarded to ``StandardTxSender``.
            opts: Transaction options, forwarded to ``StandardTxSender``.
            blockhash_refresh_interval_secs: Seconds between blockhash refreshes.
            blockhash_commitment: Commitment used when fetching blockhashes.
        """
        super().__init__(connection, opts, blockhash_commitment)
        self.blockhash_refresh_interval = blockhash_refresh_interval_secs
        self.recent_blockhash = None
        # Reference to the refresh loop started by fetch_latest_blockhash(), so
        # it is started at most once and is not garbage-collected while running.
        self._blockhash_task = None

    async def subscribe_blockhash(self):
        """
        Must be called with asyncio.create_task to prevent blocking
        """
        while True:
            try:
                blockhash_info = await self.connection.get_latest_blockhash(
                    self.blockhash_commitment
                )
                self.recent_blockhash = blockhash_info.value.blockhash
            except Exception as e:
                # Keep refreshing after a failed fetch.
                print(f"Error in subscribe_blockhash: {e}")
            await asyncio.sleep(self.blockhash_refresh_interval)

    async def fetch_latest_blockhash(self) -> Hash:
        """Return the cached blockhash, or fetch one directly if none is cached.

        The first uncached call also starts the background refresh loop. Later
        uncached calls reuse that loop instead of spawning another one.
        """
        if self.recent_blockhash is None:
            if self._blockhash_task is None or self._blockhash_task.done():
                self._blockhash_task = asyncio.create_task(self.subscribe_blockhash())
            return await super().get_blockhash()
        return self.recent_blockhash
47 |
--------------------------------------------------------------------------------
/src/driftpy/tx/jito_subscriber.py:
--------------------------------------------------------------------------------
# Importing this module always fails; everything below this statement is
# unreachable and is kept only for reference.
raise ImportError(
    "The jito_subscriber module is deprecated and has been removed in driftpy."
)
4 |
5 | import asyncio
6 | import random
7 | from typing import Optional, Tuple, Union
8 |
9 | from jito_searcher_client.async_searcher import (
10 | get_async_searcher_client, # type: ignore
11 | )
12 | from jito_searcher_client.generated.searcher_pb2 import (
13 | ConnectedLeadersRequest,
14 | ConnectedLeadersResponse,
15 | GetTipAccountsRequest,
16 | GetTipAccountsResponse,
17 | SubscribeBundleResultsRequest,
18 | ) # type: ignore
19 | from jito_searcher_client.generated.searcher_pb2_grpc import (
20 | SearcherServiceStub, # type: ignore
21 | )
22 | from solana.rpc.async_api import AsyncClient
23 | from solana.rpc.commitment import Confirmed
24 | from solders.keypair import Keypair # type: ignore
25 | from solders.pubkey import Pubkey # type: ignore
26 | from solders.system_program import TransferParams, transfer
27 |
28 |
class JitoSubscriber:
    """Cache upcoming Jito leader slots and tip accounts from a block engine.

    NOTE: this module raises ImportError at import time, so this class cannot
    currently be used.
    """

    def __init__(
        self,
        refresh_rate: int,
        kp: Keypair,
        connection: AsyncClient,
        block_engine_url: str,
    ):
        """Store settings; the searcher client is created later in subscribe()."""
        # Sorted list of leader slots greater than the last observed slot.
        self.cache = []  # type: ignore
        # Seconds to sleep between leader-slot refreshes.
        self.refresh_rate = refresh_rate
        self.kp = kp
        self.searcher_client: Optional[SearcherServiceStub] = None
        self.connection = connection
        self.tip_accounts: list[Pubkey] = []
        self.block_engine_url = block_engine_url
        self.bundle_subscription = None

    async def subscribe(self):
        """Create the searcher client and start ``_subscribe`` as a background task."""
        self.searcher_client = await get_async_searcher_client(
            self.block_engine_url, self.kp
        )
        asyncio.create_task(self._subscribe())

    async def _subscribe(self):
        """Load tip accounts once, then refresh the leader-slot cache forever."""
        self.bundle_subscription = self.searcher_client.SubscribeBundleResults(
            SubscribeBundleResultsRequest()
        )
        tip_accounts: GetTipAccountsResponse = (
            await self.searcher_client.GetTipAccounts(GetTipAccountsRequest())
        )  # type: ignore
        for account in tip_accounts.accounts:
            self.tip_accounts.append(Pubkey.from_string(account))
        while True:
            try:
                self.cache.clear()
                current_slot = (await self.connection.get_slot(Confirmed)).value
                leaders: ConnectedLeadersResponse = (
                    await self.searcher_client.GetConnectedLeaders(
                        ConnectedLeadersRequest()
                    )
                )  # type: ignore
                # Keep only slots that are still in the future.
                for slot_list in leaders.connected_validators.values():
                    slots = slot_list.slots
                    for slot in slots:
                        if slot > current_slot:
                            self.cache.append(slot)
                self.cache.sort()

            except Exception as e:
                print(e)
                await asyncio.sleep(30)
                # NOTE(review): restarts by recursing into this method, which
                # also appends the tip accounts to self.tip_accounts again.
                await self._subscribe()
            await asyncio.sleep(self.refresh_rate)

    def send_to_jito(self, current_slot: int) -> bool:
        """Return True if a cached slot lies in [current_slot - 5, current_slot + 4]."""
        for slot in range(current_slot - 5, current_slot + 5):
            if slot in self.cache:
                return True
        return False

    def get_tip_ix(self, signer: Pubkey, tip_amount: int = 1_000_000):
        """Build a transfer of ``tip_amount`` lamports from ``signer`` to a random tip account."""
        tip_account = random.choice(self.tip_accounts)
        transfer_params = TransferParams(
            from_pubkey=signer, to_pubkey=tip_account, lamports=tip_amount
        )
        return transfer(transfer_params)

    async def process_bundle_result(self, uuid: str) -> Tuple[bool, Union[int, str]]:
        """Read bundle results until one with id ``uuid`` is accepted or rejected.

        Returns:
            ``(True, slot)`` when accepted (slot defaults to 0), or
            ``(False, message)`` when rejected (message defaults to "").
        """
        while True:
            bundle_result = await self.bundle_subscription.read()  # type: ignore
            if bundle_result.bundle_id == uuid:
                if bundle_result.HasField("accepted"):
                    slot = getattr(getattr(bundle_result, "accepted"), "slot")
                    return True, slot or 0
                elif bundle_result.HasField("rejected"):
                    msg = getattr(
                        getattr(
                            getattr(bundle_result, "rejected"), "simulation_failure"
                        ),
                        "msg",
                    )
                    return False, msg or ""
111 |
--------------------------------------------------------------------------------
/src/driftpy/tx/jito_tx_sender.py:
--------------------------------------------------------------------------------
# Deprecation stub: importing this module always fails with the ImportError
# below, so no statement after this one is ever executed on import.
raise ImportError(
    "The jito_tx_sender module is deprecated and has been removed in driftpy. "
)
4 |
5 |
class JitoTxSender:
    """Placeholder for the removed Jito transaction sender.

    The constructor accepts its arguments but ignores them and always raises
    ``NotImplementedError``.
    """

    _REMOVED_MESSAGE = "JitoTxSender is deprecated and has been removed in driftpy."

    def __init__(
        self,
        drift_client,
        opts,
        block_engine_url,
        jito_keypair,
        blockhash_commitment,
        blockhash_refresh_interval_secs=None,
        tip_amount=None,
    ):
        raise NotImplementedError(self._REMOVED_MESSAGE)
20 |
--------------------------------------------------------------------------------
/src/driftpy/tx/standard_tx_sender.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Sequence
2 |
3 | from solana.rpc.async_api import AsyncClient
4 | from solana.rpc.commitment import Commitment, Confirmed
5 | from solana.rpc.types import TxOpts
6 | from solders.address_lookup_table_account import AddressLookupTableAccount
7 | from solders.hash import Hash
8 | from solders.instruction import Instruction
9 | from solders.keypair import Keypair
10 | from solders.message import MessageV0
11 | from solders.rpc.responses import SendTransactionResp
12 | from solders.transaction import VersionedTransaction
13 |
14 | from driftpy.tx.types import TxSender, TxSigAndSlot
15 |
16 |
class StandardTxSender(TxSender):
    """Transaction sender that submits a transaction over the RPC connection.

    ``send`` submits and then waits for confirmation; ``send_no_confirm``
    only submits. Transactions are built against a blockhash fetched with
    ``blockhash_commitment``.
    """

    def __init__(
        self,
        connection: AsyncClient,
        opts: TxOpts,
        blockhash_commitment: Commitment = Confirmed,
    ):
        """Store the connection and send options.

        Raises:
            ValueError: if ``opts.skip_confirmation`` is set.
        """
        self.connection = connection
        if opts.skip_confirmation:
            # Report the actual class (the old text named RetryTxSender).
            raise ValueError(
                f"{type(self).__name__} doesnt support skip confirmation"
            )
        self.opts = opts
        self.blockhash_commitment = blockhash_commitment

    async def get_blockhash(self) -> Hash:
        """Fetch the latest blockhash at ``self.blockhash_commitment``."""
        latest = await self.connection.get_latest_blockhash(self.blockhash_commitment)
        return latest.value.blockhash

    async def fetch_latest_blockhash(self) -> Hash:
        """Return the blockhash used to build transactions.

        Delegates to ``get_blockhash``.
        """
        return await self.get_blockhash()

    async def get_versioned_tx(
        self,
        ixs: Sequence[Instruction],
        payer: Keypair,
        lookup_tables: Sequence[AddressLookupTableAccount],
        additional_signers: Optional[Sequence[Keypair]] = None,
    ) -> VersionedTransaction:
        """Compile ``ixs`` into a signed v0 transaction.

        The message is compiled with ``payer`` as the paying key, the given
        lookup tables and a freshly fetched blockhash, then signed by
        ``payer`` followed by any ``additional_signers``.
        """
        latest_blockhash = await self.fetch_latest_blockhash()

        msg = MessageV0.try_compile(
            payer.pubkey(), ixs, lookup_tables, latest_blockhash
        )

        signers = [payer]
        if additional_signers is not None:
            signers.extend(additional_signers)

        return VersionedTransaction(msg, signers)

    async def _send_and_get_signature(self, tx: VersionedTransaction):
        """Submit the serialized ``tx`` and return the signature from the response.

        Raises:
            Exception: if the RPC reply is not a ``SendTransactionResp``.
        """
        body = self.connection._send_raw_transaction_body(bytes(tx), self.opts)
        resp = await self.connection._provider.make_request(body, SendTransactionResp)

        if not isinstance(resp, SendTransactionResp):
            raise Exception(f"Unexpected response from send transaction: {resp}")

        return resp.value

    async def send(self, tx: VersionedTransaction) -> TxSigAndSlot:
        """Submit ``tx``, wait for confirmation, and return its signature and slot.

        Confirmation uses ``self.opts.preflight_commitment``; the returned
        slot is the context slot of the confirmation response.

        Raises:
            Exception: if the send response has an unexpected type.
        """
        sig = await self._send_and_get_signature(tx)

        sig_status = await self.connection.confirm_transaction(
            sig, self.opts.preflight_commitment
        )

        return TxSigAndSlot(sig, sig_status.context.slot)

    async def send_no_confirm(self, tx: VersionedTransaction) -> TxSigAndSlot:
        """Submit ``tx`` without waiting for confirmation.

        The slot of the returned ``TxSigAndSlot`` is always ``0``.

        Raises:
            Exception: if the send response has an unexpected type (the same
                check ``send`` performs).
        """
        sig = await self._send_and_get_signature(tx)
        return TxSigAndSlot(sig, 0)
83 |
--------------------------------------------------------------------------------
/src/driftpy/tx/types.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from dataclasses import dataclass
3 | from typing import Optional, Sequence
4 |
5 | from solders.address_lookup_table_account import AddressLookupTableAccount
6 | from solders.instruction import Instruction
7 | from solders.keypair import Keypair
8 | from solders.signature import Signature
9 | from solders.transaction import VersionedTransaction
10 |
11 |
@dataclass
class TxSigAndSlot:
    """Pair of a transaction signature and a slot number."""

    # Signature of the submitted transaction.
    tx_sig: Signature
    # Slot reported alongside the signature by whoever builds this object.
    slot: int
16 |
17 |
class TxSender:
    """Interface for objects that build and submit versioned transactions.

    NOTE: the methods are marked ``@abstractmethod`` but the class does not
    derive from ``abc.ABC``, so Python does not enforce that subclasses
    override them; the bodies below simply return ``None``.
    """

    @abstractmethod
    async def get_versioned_tx(
        self,
        ixs: Sequence[Instruction],
        payer: Keypair,
        lookup_tables: Sequence[AddressLookupTableAccount],
        additional_signers: Optional[Sequence[Keypair]] = None,
    ) -> VersionedTransaction:
        """Build a ``VersionedTransaction`` from ``ixs``.

        ``additional_signers`` defaults to ``None`` so that the interface
        matches implementations that treat it as optional.
        """
        pass

    @abstractmethod
    async def send(self, tx: VersionedTransaction) -> TxSigAndSlot:
        """Submit ``tx`` and return its signature together with a slot."""
        pass
32 |
--------------------------------------------------------------------------------
/src/driftpy/user_map/polling_sub.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from driftpy.user_map.types import Subscription
3 |
4 |
class PollingSubscription(Subscription):
    """Subscription that re-syncs a user map on a fixed interval.

    ``frequency`` is the delay handed to ``asyncio.sleep`` between syncs; a
    value of ``0`` disables the periodic loop.
    """

    def __init__(self, user_map, frequency: float, skip_initial_load: bool = False):
        # Imported here rather than at module level; only needed for the
        # attribute annotation below.
        from driftpy.user_map.user_map import UserMap

        self.timer_task = None
        self.skip_initial_load = skip_initial_load
        self.frequency = frequency
        self.user_map: UserMap = user_map

    async def subscribe(self):
        """Run the optional initial sync and start the polling task.

        Does nothing when a polling task is already registered.
        """
        already_subscribed = self.timer_task is not None
        if already_subscribed:
            return

        if not self.skip_initial_load:
            await self.user_map.sync()

        self.timer_task = asyncio.create_task(self._polling_loop())

    async def _polling_loop(self):
        """Sleep ``self.frequency`` and sync the user map, forever."""
        # A zero frequency means the loop must not start at all.
        if self.frequency != 0:
            while True:
                await asyncio.sleep(self.frequency)
                await self.user_map.sync()

    async def unsubscribe(self):
        """Cancel the polling task, if any, and forget it."""
        task = self.timer_task
        if task is None:
            return
        task.cancel()
        self.timer_task = None
35 |
--------------------------------------------------------------------------------
/src/driftpy/user_map/types.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from driftpy.drift_user import DriftUser
3 | from solders.pubkey import Pubkey
4 | from typing import Optional
5 | from enum import Enum
6 |
7 |
class UserMapInterface(ABC):
    """Abstract interface of a user map.

    Every method is abstract; the exact semantics are defined by the
    concrete implementation. Keys are strings.
    """

    @abstractmethod
    async def subscribe(self) -> None:
        """Start the map's subscription."""

    @abstractmethod
    async def unsubscribe(self) -> None:
        """Stop the map's subscription."""

    @abstractmethod
    async def add_pubkey(self, user_account_public_key: Pubkey) -> None:
        """Add the user account identified by ``user_account_public_key``."""

    @abstractmethod
    def has(self, key: str) -> bool:
        """Return a bool for ``key`` (presumably whether it is in the map)."""

    @abstractmethod
    def get(self, key: str) -> Optional[DriftUser]:
        """Return the ``DriftUser`` for ``key``, or ``None``."""

    @abstractmethod
    async def must_get(self, key: str) -> DriftUser:
        """Return the ``DriftUser`` for ``key``; unlike ``get``, never ``None``."""

    @abstractmethod
    def get_user_authority(self, key: str) -> Optional[Pubkey]:
        """Return the authority ``Pubkey`` for ``key``, or ``None``."""

    @abstractmethod
    def values(self):
        """Return the map's values."""
40 |
41 |
class Subscription(ABC):
    """Marker base class for user-map subscriptions; it declares no members."""
44 |
45 |
class ConfigType(Enum):
    """String-valued names of the available configuration kinds."""

    CACHED = "cached"
    WEBSOCKET = "websocket"
    POLLING = "polling"
50 |
--------------------------------------------------------------------------------
/src/driftpy/user_map/user_map_config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Literal, Optional, Union
3 |
4 | from solana.rpc.async_api import AsyncClient
5 | from solana.rpc.commitment import Commitment
6 |
7 | from driftpy.drift_client import DriftClient
8 |
9 |
@dataclass
class UserAccountFilterCriteria:
    """Criteria used to filter user accounts."""

    # only return users that have open orders
    has_open_orders: bool
14 |
15 |
@dataclass
class PollingConfig:
    """Subscription settings for a polling-based user map."""

    # Polling frequency; presumably the interval between syncs — units are
    # not visible here, confirm against the consumer.
    frequency: int
    # Optional commitment level; None leaves the choice to the consumer.
    commitment: Optional[Commitment] = None
20 |
21 |
@dataclass
class WebsocketConfig:
    """Subscription settings for a websocket-based user map."""

    # Optional resubscribe timeout in milliseconds (per the field name).
    resub_timeout_ms: Optional[int] = None
    # Optional commitment level; None leaves the choice to the consumer.
    commitment: Optional[Commitment] = None
26 |
27 |
@dataclass
class UserMapConfig:
    """Configuration bundle for a UserMap."""

    # Drift client the map is built on.
    drift_client: DriftClient
    # Either polling or websocket subscription settings.
    subscription_config: Union[PollingConfig, WebsocketConfig]
    # connection object to use specifically for the UserMap.
    # If None, will use the drift_client's connection
    connection: Optional[AsyncClient] = None
    # True to skip the initial load of user_accounts via gPA
    skip_initial_load: Optional[bool] = False
    # True to include idle users when loading.
    # Defaults to false to decrease # of accounts subscribed to
    include_idle: Optional[bool] = None
40 |
41 |
@dataclass
class SyncConfig:
    """Selects how a sync is performed and, optionally, its tuning knobs."""

    # Sync strategy: "default" or "paginated".
    type: Literal["default", "paginated"]
    # Optional chunk size; presumably only relevant to "paginated" — confirm.
    chunk_size: Optional[int] = None
    # Optional concurrency limit; presumably only relevant to "paginated" — confirm.
    concurrency_limit: Optional[int] = None
47 |
48 |
@dataclass
class UserStatsMapConfig:
    """Configuration bundle for a user-stats map."""

    # Drift client the map is built on.
    drift_client: DriftClient
    # Optional dedicated connection; None leaves the choice to the consumer.
    connection: Optional[AsyncClient] = None
    # Optional sync settings; None leaves the choice to the consumer.
    sync_config: Optional[SyncConfig] = None
54 |
--------------------------------------------------------------------------------
/src/driftpy/user_map/websocket_sub.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional, TypeVar
2 | from driftpy.accounts.ws.program_account_subscriber import (
3 | WebSocketProgramAccountSubscriber,
4 | )
5 | from driftpy.memcmp import get_user_filter, get_non_idle_user_filter
6 | from driftpy.accounts.types import UpdateCallback, WebsocketProgramAccountOptions
7 |
# Result type of the optional ``decode`` callback accepted by WebsocketSubscription.
T = TypeVar("T")
9 |
10 |
class WebsocketSubscription:
    """Keeps a user map updated through a websocket program-account subscription."""

    def __init__(
        self,
        user_map,
        commitment,
        on_update: UpdateCallback,
        skip_initial_load: bool = False,
        resub_timeout_ms: Optional[int] = None,
        include_idle: bool = False,
        decode: Optional[Callable[[bytes], T]] = None,
    ):
        """Store the subscription settings; nothing is opened until ``subscribe``.

        Args:
            user_map: the UserMap this subscription feeds.
            commitment: commitment passed to the account subscriber options.
            on_update: callback handed to the account subscriber.
            skip_initial_load: when True, ``subscribe`` does not sync the map.
            resub_timeout_ms: stored on the instance; it is not forwarded to
                the subscriber by this class.
            include_idle: when False, a non-idle-user filter is added.
            decode: optional decoder handed to the account subscriber.
        """
        # Imported here rather than at module level; only needed for the
        # attribute annotation below.
        from driftpy.user_map.user_map import UserMap

        self.user_map: UserMap = user_map
        self.commitment = commitment
        self.on_update = on_update
        self.skip_initial_load = skip_initial_load
        self.resub_timeout_ms = resub_timeout_ms
        self.include_idle = include_idle
        self.subscriber = None
        self.decode = decode

    async def subscribe(self):
        """Create the account subscriber if needed, subscribe, then sync.

        The subscriber is only constructed on the first call (or after
        ``unsubscribe``); the subscribe call and the optional sync run on
        every invocation.
        """
        if not self.subscriber:
            filters = (get_user_filter(),)
            if not self.include_idle:
                filters += (get_non_idle_user_filter(),)
            options = WebsocketProgramAccountOptions(filters, self.commitment, "base64")
            self.subscriber = WebSocketProgramAccountSubscriber(
                "UserMap",
                self.user_map.drift_client.program,
                options,
                self.on_update,
                self.decode,
            )

        await self.subscriber.subscribe()

        if not self.skip_initial_load:
            await self.user_map.sync()

    async def unsubscribe(self):
        """Unsubscribe and drop the account subscriber, if one exists."""
        if not self.subscriber:
            return
        await self.subscriber.unsubscribe()
        self.subscriber = None
57 |
--------------------------------------------------------------------------------
/src/driftpy/vaults/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Some simple helpers for interacting with the vaults program.
3 |
4 | For a complete vaults SDK, please see https://github.com/drift-labs/drift-vaults
5 | """
6 |
7 | from driftpy.vaults.helpers import (
8 | fetch_all_vault_depositors,
9 | filter_vault_depositors,
10 | get_all_vaults,
11 | get_depositor_info,
12 | get_vault_by_name,
13 | get_vault_depositors,
14 | get_vault_stats,
15 | get_vaults_program,
16 | )
17 | from driftpy.vaults.vault_client import VaultClient
18 |
# Public names re-exported by the vaults package.
__all__ = [
    "get_vaults_program",
    "get_all_vaults",
    "get_vault_by_name",
    "get_vault_depositors",
    "get_vault_stats",
    "get_depositor_info",
    "filter_vault_depositors",
    "fetch_all_vault_depositors",
    "VaultClient",
]
30 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/tests/__init__.py
--------------------------------------------------------------------------------
/tests/ci/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/tests/ci/__init__.py
--------------------------------------------------------------------------------
/tests/decode/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/tests/decode/__init__.py
--------------------------------------------------------------------------------
/tests/decode/decode_stat.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import time
3 |
4 | from pathlib import Path
5 | from pytest import fixture, mark
6 |
7 | from anchorpy import Idl, Program
8 |
9 | from solders.pubkey import Pubkey
10 |
11 | import driftpy
12 | from driftpy.types import UserStatsAccount
13 | from driftpy.decode.user_stat import decode_user_stat
14 |
15 | from tests.decode.stat_decode_strings import stats
16 |
17 |
@fixture(scope="session")
def program() -> Program:
    """Session-scoped anchorpy ``Program`` built from driftpy's bundled IDL.

    Reads ``idl/drift.json`` from the installed ``driftpy`` package. The file
    is read once with ``read_text`` (the previous version also opened a
    handle that it never used).
    """
    idl_path = Path(driftpy.__path__[0]) / "idl" / "drift.json"
    idl = Idl.from_json(idl_path.read_text())
    return Program(
        idl,
        Pubkey.from_string("dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH"),
    )
28 |
29 |
@mark.asyncio
async def test_user_stat_decode(program: Program):
    """Decode every sample in ``stats`` with both decoders and print the totals."""
    anchor_total: int = 0
    custom_total: int = 0

    for position, encoded in enumerate(stats):
        anchor_elapsed, custom_elapsed = user_stats_decode(
            program, base64.b64decode(encoded), position
        )
        anchor_total += anchor_elapsed
        custom_total += custom_elapsed

    print("Total anchor time:", anchor_total)
    print("Total custom time:", custom_total)
43 |
44 |
def user_stats_decode(program: Program, buffer: bytes, index: int):
    """Decode ``buffer`` with the anchor coder and the custom decoder and compare.

    Args:
        program: anchorpy program whose account coder is the reference decoder.
        buffer: raw bytes of one user-stats account.
        index: position of the sample, used only in the progress message.

    Returns:
        ``(anchor_time, custom_time)``: elapsed whole milliseconds for each
        decoder, measured with the monotonic ``time.perf_counter`` clock.

    Raises:
        AssertionError: if any compared field differs between the decoders.
    """
    print("Benchmarking user stats decode: ", index)

    def now_ms() -> int:
        # perf_counter is monotonic, unlike time.time, which can jump when
        # the wall clock is adjusted.
        return int(time.perf_counter() * 1_000)

    anchor_start_ts = now_ms()
    anchor_user_stats: UserStatsAccount = program.coder.accounts.decode(buffer)
    anchor_time = now_ms() - anchor_start_ts

    custom_start_ts = now_ms()
    custom_user_stats = decode_user_stat(buffer)
    custom_time = now_ms() - custom_start_ts

    # Public keys are compared through their string form.
    for name in ("authority", "referrer"):
        assert str(getattr(anchor_user_stats, name)) == str(
            getattr(custom_user_stats, name)
        ), f"mismatch in {name}"

    fee_fields = (
        "total_fee_paid",
        "total_fee_rebate",
        "total_token_discount",
        "total_referee_discount",
        "total_referrer_reward",
        "current_epoch_referrer_reward",
    )
    for name in fee_fields:
        assert getattr(anchor_user_stats.fees, name) == getattr(
            custom_user_stats.fees, name
        ), f"mismatch in fees.{name}"

    plain_fields = (
        "next_epoch_ts",
        "maker_volume30d",
        "taker_volume30d",
        "filler_volume30d",
        "last_maker_volume30d_ts",
        "last_taker_volume30d_ts",
        "last_filler_volume30d_ts",
        "if_staked_quote_asset_amount",
        "number_of_sub_accounts",
        "number_of_sub_accounts_created",
        "is_referrer",
        "disable_update_perp_bid_ask_twap",
    )
    for name in plain_fields:
        assert getattr(anchor_user_stats, name) == getattr(
            custom_user_stats, name
        ), f"mismatch in {name}"

    return (anchor_time, custom_time)
118 |
--------------------------------------------------------------------------------
/tests/decode/dlob_test_helpers.py:
--------------------------------------------------------------------------------
1 | from solders.pubkey import Pubkey
2 |
3 | from typing import Optional
4 | from driftpy.dlob.dlob import DLOB
5 |
6 | from driftpy.types import (
7 | MarketType,
8 | Order,
9 | OrderStatus,
10 | OrderTriggerCondition,
11 | OrderType,
12 | PositionDirection,
13 | )
14 |
15 |
def insert_order_to_dlob(
    dlob: DLOB,
    user_account: Pubkey,
    order_type: OrderType,
    market_type: MarketType,
    order_id: int,
    market_index: int,
    price: int,
    base_asset_amount: int,
    direction: PositionDirection,
    auction_start_price: int,
    auction_end_price: int,
    slot: Optional[int] = None,
    max_ts=0,
    oracle_price_offset=0,
    post_only=False,
    auction_duration=10,
):
    """Build an open ``Order`` from the arguments and insert it into ``dlob``.

    ``slot`` defaults to ``1`` when it is ``None``; the same value is used
    both inside the order and as the slot passed to ``dlob.insert_order``.
    The ``Order`` arguments are positional, so their meaning is defined by
    the field order of ``driftpy.types.Order`` (not visible here).
    """
    effective_slot = 1 if slot is None else slot
    order_fields = (
        effective_slot,
        price,
        base_asset_amount,
        0,
        0,
        0,
        auction_start_price,
        auction_end_price,
        max_ts,
        oracle_price_offset,
        order_id,
        market_index,
        OrderStatus.Open(),
        order_type,
        market_type,
        0,
        PositionDirection.Long(),
        direction,
        False,
        post_only,
        False,
        OrderTriggerCondition.Above(),
        auction_duration,
        [0, 0, 0],
    )
    dlob.insert_order(Order(*order_fields), user_account, effective_slot)
62 |
63 |
def insert_trigger_order_to_dlob(
    dlob: DLOB,
    user_account: Pubkey,
    order_type: OrderType,
    market_type: MarketType,
    order_id: int,
    market_index: int,
    price: int,
    base_asset_amount: int,
    direction: PositionDirection,
    trigger_price: int,
    trigger_condition: OrderTriggerCondition,
    auction_start_price: int,
    auction_end_price: int,
    slot: Optional[int] = None,
    max_ts=0,
    oracle_price_offset=0,
):
    """Build an open trigger ``Order`` and insert it into ``dlob``.

    ``slot`` defaults to ``1`` only when it is ``None``, so an explicit slot
    of ``0`` is kept (previously ``slot or 1`` silently turned it into 1,
    unlike ``insert_order_to_dlob``). The same value is used inside the order
    and as the slot passed to ``dlob.insert_order``. The ``Order`` arguments
    are positional, so their meaning is defined by the field order of
    ``driftpy.types.Order`` (not visible here).
    """
    slot = slot if slot is not None else 1
    order = Order(
        slot,
        price,
        base_asset_amount,
        0,
        0,
        trigger_price,
        auction_start_price,
        auction_end_price,
        max_ts,
        oracle_price_offset,
        order_id,
        market_index,
        OrderStatus.Open(),
        order_type,
        market_type,
        0,
        PositionDirection.Long(),
        direction,
        False,
        False,
        True,
        trigger_condition,
        # NOTE(review): max_ts is passed a second time here, in the position
        # where insert_order_to_dlob passes auction_duration — confirm this
        # is intended before changing it.
        max_ts,
        [0, 0, 0],
    )
    dlob.insert_order(order, user_account, slot)
110 |
--------------------------------------------------------------------------------
/tests/dlob/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/tests/dlob/__init__.py
--------------------------------------------------------------------------------
/tests/integration/events_parser.py:
--------------------------------------------------------------------------------
1 | from pytest import mark
2 |
3 | from solana.rpc.async_api import AsyncClient
4 | from solders.signature import Signature
5 | from anchorpy import Wallet
6 |
7 | from driftpy.events.event_subscriber import EventSubscriber
8 | from driftpy.drift_client import DriftClient
9 |
10 |
@mark.asyncio
async def test_events_parser():
    """Fetch one mainnet transaction and parse drift events from its logs.

    The RPC client is used as an async context manager so its connection is
    closed when the test finishes, including when it fails. The parsed
    events are printed; the test makes no assertion on their content.
    """
    # Parsed once and reused for both the fetch and the event parser.
    tx_sig = Signature.from_string(
        "3JRzMVquzXXmbV7cPMxiMRQp25scFFkYpsntWjMJQv3i4sMZyVXAGi6X2vAHNkqH1mkNtLvpp4oT6iorzZgLkYNY"
    )

    async with AsyncClient("https://api.mainnet-beta.solana.com") as connection:
        drift_client = DriftClient(connection, Wallet.dummy())

        event_subscriber = EventSubscriber(
            drift_client.connection, drift_client.program
        )

        tx = await connection.get_transaction(
            tx_sig,
            max_supported_transaction_version=0,
        )
        logs = tx.value.transaction.meta.log_messages

        events = event_subscriber.parse_events_from_logs(
            tx_sig,
            tx.value.slot,
            logs,
        )

    print(events)
35 |
--------------------------------------------------------------------------------
/tests/integration/swb_on_demand.py:
--------------------------------------------------------------------------------
1 | from pytest import mark, approx
2 |
3 | from solana.rpc.async_api import AsyncClient
4 |
5 | from solders.pubkey import Pubkey
6 |
7 | from driftpy.accounts.oracle import (
8 | decode_oracle,
9 | get_oracle_decode_fn,
10 | get_oracle_price_data_and_slot,
11 | SWB_ON_DEMAND_CODER,
12 | )
13 | from driftpy.accounts.types import OracleSource
14 | from driftpy.constants.numeric_constants import SWB_PRECISION
15 |
16 |
@mark.asyncio
async def test_swb_on_demand():
    """Check that the Switchboard on-demand oracle decoders agree with each other.

    Fetches one mainnet oracle account and compares four views of it: the
    fetched price data, the raw coder output, the decode function returned by
    ``get_oracle_decode_fn`` and ``decode_oracle``. The RPC client is used as
    an async context manager so its connection is closed when the test
    finishes, including when it fails.
    """
    oracle = Pubkey.from_string("EZLBfnznMYKjFmaWYMEdhwnkiQF1WiP9jjTY6M8HpmGE")
    oracle_source = OracleSource.SwitchboardOnDemand()

    async with AsyncClient("https://api.mainnet-beta.solana.com") as connection:
        oracle_fetched = await get_oracle_price_data_and_slot(
            connection, oracle, oracle_source
        )

        raw = (await connection.get_account_info(oracle)).value.data

    oracle_unstructured = SWB_ON_DEMAND_CODER.accounts.decode(raw).result

    decode = get_oracle_decode_fn(oracle_source)
    oracle_decode_fn = decode(raw)

    oracle_decode_oracle = decode_oracle(raw, oracle_source)

    # these two should be identical
    assert oracle_decode_oracle.price == oracle_decode_fn.price
    assert oracle_decode_oracle.slot == oracle_decode_fn.slot
    assert oracle_decode_oracle.confidence == oracle_decode_fn.confidence
    assert oracle_decode_oracle.twap == oracle_decode_fn.twap
    assert oracle_decode_oracle.twap_confidence == oracle_decode_fn.twap_confidence
    assert (
        oracle_decode_oracle.has_sufficient_number_of_data_points
        == oracle_decode_fn.has_sufficient_number_of_data_points
    )

    # potential slight diff from slot drift
    assert oracle_fetched.data.price == approx(oracle_decode_oracle.price)
    assert oracle_fetched.data.slot == approx(oracle_decode_oracle.slot)
    assert oracle_fetched.data.confidence == approx(oracle_decode_oracle.confidence)
    assert oracle_fetched.data.twap == oracle_decode_oracle.twap
    assert oracle_fetched.data.twap_confidence == oracle_decode_oracle.twap_confidence
    assert (
        oracle_fetched.data.has_sufficient_number_of_data_points
        == oracle_decode_oracle.has_sufficient_number_of_data_points
    )

    # the raw coder output is scaled by SWB_PRECISION relative to the price data
    assert oracle_unstructured.value == approx(
        oracle_fetched.data.price * SWB_PRECISION
    )
    assert oracle_unstructured.slot == approx(oracle_fetched.data.slot)
    assert oracle_unstructured.range == approx(
        oracle_fetched.data.confidence * SWB_PRECISION
    )

    assert (oracle_unstructured.value // SWB_PRECISION) == oracle_decode_oracle.price
    assert oracle_unstructured.slot == oracle_decode_oracle.slot
    assert (
        oracle_unstructured.range // SWB_PRECISION
    ) == oracle_decode_oracle.confidence
70 |
--------------------------------------------------------------------------------
/tests/math/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drift-labs/driftpy/359c640020a0b9f7c33e0dd90630ab8830ec3d07/tests/math/__init__.py
--------------------------------------------------------------------------------
/tests/math/auction.py:
--------------------------------------------------------------------------------
1 | from pytest import mark
2 |
3 | from driftpy.math.auction import derive_oracle_auction_params
4 | from driftpy.types import PositionDirection
5 | from driftpy.constants.numeric_constants import PRICE_PRECISION
6 |
7 |
@mark.asyncio
async def test_drive_oracle_auction_params():
    """Check ``derive_oracle_auction_params`` for long and short orders.

    Each case calls the function with (direction, oracle price, auction start
    price, auction end price, limit price) and compares the first three
    entries of the result with the expected values.
    """

    def expect_first_three(params, expected):
        # Compare entry by entry, exactly as indexed access would.
        for position, wanted in enumerate(expected):
            assert params[position] == wanted

    oracle = 100 * PRICE_PRECISION

    # Long order with prices spread around the oracle.
    expect_first_three(
        derive_oracle_auction_params(
            PositionDirection.Long(),
            oracle,
            90 * PRICE_PRECISION,
            110 * PRICE_PRECISION,
            120 * PRICE_PRECISION,
        ),
        (-10 * PRICE_PRECISION, 10 * PRICE_PRECISION, 20 * PRICE_PRECISION),
    )

    # Long order with every price equal to the oracle.
    expect_first_three(
        derive_oracle_auction_params(
            PositionDirection.Long(), oracle, oracle, oracle, oracle
        ),
        (0, 0, 1),
    )

    # Short order with prices spread around the oracle.
    expect_first_three(
        derive_oracle_auction_params(
            PositionDirection.Short(),
            oracle,
            110 * PRICE_PRECISION,
            90 * PRICE_PRECISION,
            80 * PRICE_PRECISION,
        ),
        (10 * PRICE_PRECISION, -10 * PRICE_PRECISION, -20 * PRICE_PRECISION),
    )

    # Short order with every price equal to the oracle.
    expect_first_three(
        derive_oracle_auction_params(
            PositionDirection.Short(), oracle, oracle, oracle, oracle
        ),
        (0, 0, -1),
    )
63 |
--------------------------------------------------------------------------------
/tests/math/insurance.py:
--------------------------------------------------------------------------------
1 | from pytest import mark
2 |
3 | from driftpy.math.utils import time_remaining_until_update
4 |
5 |
@mark.asyncio
async def test_time_remaining_updates():
    """Check ``time_remaining_until_update`` against a table of known results."""
    now = 1_683_576_852
    last_update = 1_683_576_000
    period = 3_600
    too_late = last_update - ((period // 3) + 1)

    # (now, last_update, period, expected result)
    cases = [
        (now, last_update, period, 2_748),
        (now, last_update - period, period, 0),
        (too_late + 1, too_late, period, 4_800),
        (now, last_update + 1, period, 2_748),
        (now, last_update - 1, period, 2_748),
    ]

    for current, previous, interval, expected in cases:
        tr = time_remaining_until_update(current, previous, interval)
        assert tr == expected
27 |
--------------------------------------------------------------------------------
/update_idl.sh:
--------------------------------------------------------------------------------
1 | git submodule update --remote --merge --recursive &&
2 | cd protocol-v2/ &&
3 | anchor build &&
4 | cp target/idl/drift.json ../src/driftpy/idl/drift.json
--------------------------------------------------------------------------------