├── .coveragerc ├── .flake8 ├── .github └── workflows │ ├── publish-to-pypi.yml │ └── pull-request.yaml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── design └── diagram.png ├── docker-compose.yml ├── libs ├── __init__.py └── ib_client │ ├── .gitignore │ ├── MANIFEST.in │ ├── README.md │ ├── __init__.py │ ├── ibapi │ ├── __init__.py │ ├── account_summary_tags.py │ ├── client.py │ ├── comm.py │ ├── commission_report.py │ ├── common.py │ ├── connection.py │ ├── contract.py │ ├── decoder.py │ ├── enum_implem.py │ ├── errors.py │ ├── execution.py │ ├── ibapi.pyproj │ ├── message.py │ ├── news.py │ ├── object_implem.py │ ├── order.py │ ├── order_condition.py │ ├── order_state.py │ ├── orderdecoder.py │ ├── reader.py │ ├── scanner.py │ ├── server_versions.py │ ├── softdollartier.py │ ├── tag_value.py │ ├── ticktype.py │ ├── utils.py │ └── wrapper.py │ ├── setup.py │ ├── tests │ ├── __init__.py │ ├── manual.py │ ├── test_account_summary_tags.py │ ├── test_comm.py │ ├── test_enum_implem.py │ ├── test_order_conditions.py │ └── test_utils.py │ └── tox.ini ├── poetry.lock ├── pyproject.toml ├── src └── algotrader │ ├── __init__.py │ ├── assets │ ├── assets_provider.py │ ├── correlation_config.json │ ├── crypto.symbols │ └── sp500.symbols │ ├── calc │ ├── __init__.py │ ├── calculations.py │ └── technicals.py │ ├── cli │ ├── __init__.py │ ├── helpers.py │ ├── main.py │ ├── pipeline.py │ ├── processors.py │ ├── sources.py │ └── strategies.py │ ├── entities │ ├── __init__.py │ ├── attachments │ │ ├── __init__.py │ │ ├── assets_correlation.py │ │ ├── nothing.py │ │ ├── returns.py │ │ ├── technicals.py │ │ ├── technicals_buckets_matcher.py │ │ └── technicals_normalizer.py │ ├── base_dto.py │ ├── bucket.py │ ├── bucketscontainer.py │ ├── candle.py │ ├── candle_attachments.py │ ├── event.py │ ├── generic_candle_attachment.py │ ├── order_direction.py │ ├── serializable.py │ ├── strategy.py │ ├── strategy_signal.py │ └── timespan.py │ ├── examples │ ├── mongo-history-by-buckets.md │ └── pipeline-templates │ │ ├── backtest_history_buckets_backtester.json │ │ ├── backtest_history_similarity_backtester.json │ │ ├── backtest_mongo_source_rsi_strategy.json │ │ ├── backtest_technicals_with_buckets_calculator.json │ │ ├── bins.json │ │ ├── build_daily_binance_loader.json │ │ ├── build_daily_yahoo_loader.json │ │ ├── build_realtime_binance.json │ │ ├── correlation.json │ │ ├── loader_simple_daily_loader.json │ │ ├── loader_simple_returns_calculator.json │ │ ├── loader_simple_technicals_calculator.json │ │ └── loader_technicals_with_buckets_matcher.json │ ├── logger │ └── __init__.py │ ├── main.py │ ├── market │ ├── __init__.py │ ├── async_market_provider.py │ ├── async_query_result.py │ ├── ib_market.py │ ├── market_provider.py │ └── yahoofinance │ │ ├── __init__.py │ │ └── history_provider.py │ ├── pipeline │ ├── __init__.py │ ├── builders │ │ ├── __init__.py │ │ ├── backtest.py │ │ └── loaders.py │ ├── configs │ │ ├── __init__.py │ │ ├── indicator_config.py │ │ └── technical_processor_config.py │ ├── pipeline.py │ ├── processor.py │ ├── processors │ │ ├── __init__.py │ │ ├── assets_correlation.py │ │ ├── candle_cache.py │ │ ├── file_sink.py │ │ ├── returns.py │ │ ├── storage_provider_sink.py │ │ ├── strategy.py │ │ ├── technicals.py │ │ ├── technicals_buckets_matcher.py │ │ ├── technicals_normalizer.py │ │ └── timespan_change.py │ ├── reverse_source.py │ ├── runner.py │ ├── shared_context.py │ ├── source.py │ ├── sources │ │ ├── __init__.py │ │ ├── 
binance_history.py │ │ ├── binance_realtime.py │ │ ├── ib_history.py │ │ ├── mongodb_source.py │ │ └── yahoo_finance_history.py │ ├── strategies │ │ ├── __init__.py │ │ ├── connors_rsi2.py │ │ ├── history_bucket_compare.py │ │ ├── history_cosine_similarity.py │ │ └── simple_sma.py │ ├── terminator.py │ └── terminators │ │ ├── __init__.py │ │ └── technicals_binner.py │ ├── providers │ ├── __init__.py │ ├── binance.py │ └── ib │ │ ├── __init__.py │ │ ├── ib_interval.py │ │ ├── interactive_brokers_connector.py │ │ └── query_subscription.py │ ├── serialization │ ├── __init__.py │ └── store.py │ ├── storage │ ├── __init__.py │ ├── inmemory_storage.py │ ├── mongodb_storage.py │ └── storage_provider.py │ └── trade │ ├── __init__.py │ ├── signals_executor.py │ ├── simple_sum_signals_executor.py │ └── stdout_signals_executor.py └── tests ├── configs └── correlations.json ├── fakes ├── __init__.py ├── pipeline_validators.py ├── source.py └── strategy_executor.py ├── integration ├── __init__.py ├── test_binance_provider.py ├── test_ib_provider.py ├── test_ib_source.py ├── test_yahoo_provider.py └── test_yahoo_source.py └── unit ├── __init__.py ├── strategies ├── __init__.py ├── test_history_compare.py └── test_simple_sma.py ├── test_asset_correlation.py ├── test_assets_provider.py ├── test_async_query_result.py ├── test_candle_cache.py ├── test_filesink_processor.py ├── test_inmemory_storage.py ├── test_mongo_source.py ├── test_mongodb_sink_processor.py ├── test_mongodb_storage.py ├── test_multiple_pipelines.py ├── test_returns_calculator_processor.py ├── test_reverse_source.py ├── test_serialization.py ├── test_serializations.py ├── test_simple_sum_signals_executor.py ├── test_strategy_processor.py ├── test_technicals_binner_terminator.py ├── test_technicals_processor.py └── test_timespan_change_processor.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | relative_files = true 3 | branch = True 4 | omit = 5 | */src/libs/* 6 | 7 | [report] 8 | omit = 9 | */src/libs/* 10 | 11 | 12 | # Regexes for lines to exclude from consideration 13 | exclude_lines = 14 | # Have to re-enable the standard pragma 15 | pragma: no cover 16 | 17 | # Don't complain about missing debug-only code: 18 | def __repr__ 19 | if self\.debug 20 | 21 | # Don't complain if tests don't hit defensive assertion code: 22 | raise AssertionError 23 | raise NotImplementedError 24 | 25 | # Don't complain if non-runnable code isn't run: 26 | if 0: 27 | if __name__ == .__main__.: 28 | 29 | # Don't complain about abstract methods, they aren't run: 30 | @(abc\.)?abstractmethod 31 | 32 | ignore_errors = True 33 | 34 | [html] 35 | directory = coverage_html_report -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude=build/,run/,libs/,venv/,src/algotrader/pipeline/processors/__init__.py,src/algotrader/pipeline/strategies/__init__.py,src/algotrader/pipeline/sources/__init__.py 3 | max-line-length=160 4 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | on: 3 | push: 4 | tags: 5 | - "v*.*.*" 6 | jobs: 7 | build: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v3 11 | - name: Build and publish to pypi 12 | uses: JRubics/poetry-publish@v1.17 13 | with: 
14 | pypi_token: ${{ secrets.PYPI_API_TOKEN }} 15 | -------------------------------------------------------------------------------- /.github/workflows/pull-request.yaml: -------------------------------------------------------------------------------- 1 | name: Pull Request 2 | 3 | on: 4 | pull_request: 5 | 6 | jobs: 7 | build: 8 | 9 | runs-on: ubuntu-latest 10 | strategy: 11 | matrix: 12 | python-version: ["3.11"] 13 | 14 | steps: 15 | - uses: actions/checkout@v3 16 | 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | 22 | - name: Install Poetry 23 | uses: snok/install-poetry@v1 24 | with: 25 | virtualenvs-create: true 26 | virtualenvs-in-project: true 27 | installer-parallel: true 28 | 29 | - name: Install dependencies 30 | run: poetry install --no-interaction 31 | 32 | - name: Lint with flake8 33 | run: | 34 | poetry run make lint 35 | 36 | - name: Run unit tests 37 | run: | 38 | poetry run make test -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | _trial_temp/ 2 | .idea 3 | .DS_Store 4 | *.pyc 5 | *.log 6 | run/ 7 | .env 8 | private/ 9 | dist/ 10 | build/ 11 | *.egg-info 12 | *.egg 13 | logs/ 14 | .coverage 15 | /coverage.xml 16 | /report.xml 17 | /.ruff_cache/ 18 | /.pytest_cache/ 19 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Thank you for considering contributing to algo-trader! 4 | 5 | ## How to start 6 | If you encounter a bug you wish to fix, or have a feature request, it is best to create an issue with the relevant 7 | information, so we can talk about the best course of action. 8 | The next step is to fork the repo and create a bug/feature branch. 9 | For better visibility, it is best if the branch name has the issue number as its prefix, for example, `123-fix-that`. 10 | 11 | ## Set up your local environment 12 | After forking and branching, make sure your local environment can run and pass the tests. 13 | Running the tests can be done with the provided scripts in the [scripts](scripts) folder. 14 | 15 | ## What needs to be done 16 | All open issues and feature requests should be listed in the issues tab. 17 | If you'd like to work on a new feature, check out the [enhancement](https://github.com/idanya/algo-trader/labels/enhancement) label. 18 | For fixing existing bugs, check out the [bugs](https://github.com/idanya/algo-trader/labels/bug) label. 19 | 20 | Please make sure the feature/bug you are working on is confirmed and approved in the comments and/or added labels. 21 | 22 | ## Make sure new code has test coverage 23 | Please make sure all new code has the relevant tests to cover its logic. In cases of bug fixes, make sure to change the relevant tests to reflect the new behavior. 24 | If you fixed a bug and no test needed to be changed, it is a good indicator that a test is missing and should be added. 25 | 26 | ## Style and convention 27 | Every one of us has their own style and way of writing code. That's OK, and that's great. 28 | With that in mind, please make an effort to stay in line with the repo's conventions and design so the code will be easy to read and understand. 29 | 30 | ## Make a Pull Request 31 | When you think your code is ready, which means: 32 | 33 | 1. 
It fixes the issue it is meant to fix OR adds the needed functionality 34 | 2. The relevant existing/new tests check the new behavior and pass 35 | 3. Style and convention are in line with the repo "spirit" 36 | 37 | Create a PR and reference the issue it solves in the description. 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Idan Yael 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test test-integration lint 2 | SHELL = /bin/bash 3 | 4 | test: 5 | pytest tests/unit --cov --cov-report term --cov-report xml:coverage.xml --junit-xml=report.xml 6 | 7 | test-integration: 8 | pytest tests/integration --cov --cov-report term --cov-report xml:coverage.xml --junit-xml=report.xml 9 | 10 | lint: 11 | ruff ./src/ ./tests/ 12 | black ./src/ ./tests/ --check 13 | #pyright ./src/ ./tests/ 14 | 15 | reformat: 16 | black ./src/ ./tests/ 17 | ruff ./src/ ./tests/ --fix 18 | -------------------------------------------------------------------------------- /design/diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/design/diagram.png -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | # Use root/root as user/password credentials 2 | version: '3.1' 3 | 4 | services: 5 | 6 | mongo: 7 | image: mongo 8 | restart: always 9 | ports: 10 | - 27017:27017 11 | environment: 12 | MONGO_INITDB_ROOT_USERNAME: root 13 | MONGO_INITDB_ROOT_PASSWORD: root 14 | 15 | mongo-express: 16 | image: mongo-express 17 | restart: always 18 | ports: 19 | - 8081:8081 20 | environment: 21 | ME_CONFIG_BASICAUTH_USERNAME: root 22 | ME_CONFIG_BASICAUTH_PASSWORD: root 23 | ME_CONFIG_MONGODB_ADMINUSERNAME: root 24 | ME_CONFIG_MONGODB_ADMINPASSWORD: root 25 | ME_CONFIG_MONGODB_URL: mongodb://root:root@mongo:27017/ -------------------------------------------------------------------------------- /libs/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/libs/__init__.py -------------------------------------------------------------------------------- /libs/ib_client/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | tags 3 | __pycache__/* 4 | log/* 5 | *.log 6 | log.* 7 | core 8 | *.xml 9 | MANIFEST 10 | dist 11 | build 12 | .idea 13 | *.egg-info 14 | /.tox/ 15 | -------------------------------------------------------------------------------- /libs/ib_client/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include ibapi/*.py 2 | include README.md 3 | -------------------------------------------------------------------------------- /libs/ib_client/README.md: -------------------------------------------------------------------------------- 1 | A couple of things/definitions/conventions: 2 | * a *low level message* is some data prefixed with its size 3 | * a *high level message* is a list of fields separated by the NULL character; the fields are all strings; the message ID is the first field, then come others whose number and semantics depend on the message itself 4 | * a *request* is a message from client to TWS/IBGW (IB Gateway) 5 | * an *answer* is a message from TWS/IBGW to client 6 | 7 | 8 | How the code is organized: 9 | * *comm* module: has tools that know how to handle (e.g. encode/decode) low and high level messages 10 | * *Connection*: glorified socket 11 | * *Reader*: thread that uses Connection to read packets, transforms them into low level messages and puts them in a Queue 12 | * *Decoder*: knows how to take a low level message and decode it into a high level message 13 | * *Client*: 14 | + knows how to send requests 15 | + has the message loop which takes low level messages from the Queue and uses Decoder to transform them into high level messages with which it then calls the corresponding Wrapper method 16 | * *Wrapper*: class that needs to be subclassed by the user so that it can get the incoming messages 17 | 18 | 19 | The info/data flow is: 20 | 21 | * receiving: 22 | + *Connection.recv_msg()* (which is essentially a socket) receives the packets 23 | - uses *Connection._recv_all_msgs()* which tries to combine smaller packets into bigger ones based on some trivial heuristic 24 | + *Reader.run()* uses *Connection.recv_msg()* to get a packet and then uses *comm.read_msg()* to try to make it a low level message. If that can't be done yet (the size prefix says so) then it waits for more packets 25 | + if a full low level message is received then it is placed in the Queue (remember this is a standalone thread) 26 | + the main thread runs the *Client.run()* loop which: 27 | - gets a low level message from the Queue 28 | - uses *comm.py* to translate it into a high level message (fields) 29 | - uses *Decoder.interpret()* to act based on that message 30 | + *Decoder.interpret()* will translate the fields into function parameters of the correct type and call the corresponding method of the *Wrapper* class 31 | 32 | * sending: 33 | + the *Client* class has methods that implement the _requests_. The user will call those request methods with the needed parameters and *Client* will send them to the TWS/IBGW. 
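A minimal usage sketch of this flow, following the common pattern from the IB samples, is shown below: a single object subclasses both *EWrapper* (to receive answers) and *EClient* (to send requests), and the main thread drives the message loop. The host/port 127.0.0.1:7497, client id, request id and the AAPL contract are illustrative assumptions only:

    from ibapi.client import EClient
    from ibapi.contract import Contract
    from ibapi.wrapper import EWrapper

    class App(EWrapper, EClient):
        """One object acts as both the request sender (EClient) and the answer receiver (EWrapper)."""

        def __init__(self):
            EClient.__init__(self, self)

        def nextValidId(self, orderId):
            # first callback after the handshake; a safe point to start sending requests
            contract = Contract()
            contract.symbol = "AAPL"        # assumption: any valid contract definition works here
            contract.secType = "STK"
            contract.exchange = "SMART"
            contract.currency = "USD"
            self.reqMktData(1, contract, "", False, False, [])

        def tickPrice(self, reqId, tickType, price, attrib):
            # invoked by Decoder.interpret() for each incoming tick-price message
            print("tick", reqId, tickType, price)

    app = App()
    app.connect("127.0.0.1", 7497, clientId=0)   # assumption: TWS paper-trading port
    app.run()                                    # Client.run() loop: Queue -> Decoder -> Wrapper callbacks

reqMktData() is just one example request; every other *Client* request method follows the same sending path described above.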
34 | 35 | 36 | Implementation notes: 37 | 38 | * the *Decoder* has two ways of handling a message (essentially decoding the fields) 39 | + some messages map very neatly to a function call, meaning that the number and order of fields are the same as the method parameters. For example: Wrapper.tickSize(). In this case a simple mapping is made between the incoming msg id and the Wrapper method: 40 | 41 | IN.TICK_SIZE: HandleInfo(wrap=Wrapper.tickSize), 42 | 43 | + other messages are more complex, depend heavily on the version number or need field massaging. In this case the incoming message id is mapped to a processing function that will do all that and call the Wrapper method at the end. For example: 44 | 45 | IN.TICK_PRICE: HandleInfo(proc=processTickPriceMsg), 46 | 47 | 48 | Installation notes: 49 | 50 | * you can use this to build a source distribution 51 | 52 | python3 setup.py sdist 53 | 54 | * you can use this to build a wheel 55 | 56 | python3 setup.py bdist_wheel 57 | 58 | * you can use this to install the wheel 59 | 60 | python3 -m pip install --user --upgrade dist/ibapi-9.75.1-py3-none-any.whl 61 | 62 | 63 | -------------------------------------------------------------------------------- /libs/ib_client/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/libs/ib_client/__init__.py -------------------------------------------------------------------------------- /libs/ib_client/ibapi/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 4 | """ 5 | 6 | """ Package implementing the Python API for the TWS/IB Gateway """ 7 | 8 | VERSION = { 9 | 'major': 9, 10 | 'minor': 76, 11 | 'micro': 1} 12 | 13 | 14 | def get_version_string(): 15 | version = '{major}.{minor}.{micro}'.format(**VERSION) 16 | return version 17 | 18 | __version__ = get_version_string() 19 | -------------------------------------------------------------------------------- /libs/ib_client/ibapi/account_summary_tags.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 
4 | """ 5 | 6 | class AccountSummaryTags: 7 | AccountType = "AccountType" 8 | NetLiquidation = "NetLiquidation" 9 | TotalCashValue = "TotalCashValue" 10 | SettledCash = "SettledCash" 11 | AccruedCash = "AccruedCash" 12 | BuyingPower = "BuyingPower" 13 | EquityWithLoanValue = "EquityWithLoanValue" 14 | PreviousEquityWithLoanValue = "PreviousEquityWithLoanValue" 15 | GrossPositionValue = "GrossPositionValue" 16 | ReqTEquity = "ReqTEquity" 17 | ReqTMargin = "ReqTMargin" 18 | SMA = "SMA" 19 | InitMarginReq = "InitMarginReq" 20 | MaintMarginReq = "MaintMarginReq" 21 | AvailableFunds = "AvailableFunds" 22 | ExcessLiquidity = "ExcessLiquidity" 23 | Cushion = "Cushion" 24 | FullInitMarginReq = "FullInitMarginReq" 25 | FullMaintMarginReq = "FullMaintMarginReq" 26 | FullAvailableFunds = "FullAvailableFunds" 27 | FullExcessLiquidity = "FullExcessLiquidity" 28 | LookAheadNextChange = "LookAheadNextChange" 29 | LookAheadInitMarginReq = "LookAheadInitMarginReq" 30 | LookAheadMaintMarginReq = "LookAheadMaintMarginReq" 31 | LookAheadAvailableFunds = "LookAheadAvailableFunds" 32 | LookAheadExcessLiquidity = "LookAheadExcessLiquidity" 33 | HighestSeverity = "HighestSeverity" 34 | DayTradesRemaining = "DayTradesRemaining" 35 | Leverage = "Leverage" 36 | 37 | AllTags = ",".join((AccountType, NetLiquidation, TotalCashValue, 38 | SettledCash, AccruedCash, BuyingPower, EquityWithLoanValue, 39 | PreviousEquityWithLoanValue, GrossPositionValue, ReqTEquity, 40 | ReqTMargin, SMA, InitMarginReq, MaintMarginReq, AvailableFunds, 41 | ExcessLiquidity , Cushion, FullInitMarginReq, FullMaintMarginReq, 42 | FullAvailableFunds, FullExcessLiquidity, 43 | LookAheadNextChange, LookAheadInitMarginReq, LookAheadMaintMarginReq, 44 | LookAheadAvailableFunds, LookAheadExcessLiquidity, HighestSeverity, 45 | DayTradesRemaining, Leverage)) 46 | 47 | 48 | -------------------------------------------------------------------------------- /libs/ib_client/ibapi/comm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 4 | """ 5 | 6 | 7 | """ 8 | This module has tools for implementing the IB low level messaging. 
9 | """ 10 | 11 | 12 | import struct 13 | import logging 14 | 15 | from ibapi.common import UNSET_INTEGER, UNSET_DOUBLE 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def make_msg(text) -> bytes: 21 | """ adds the length prefix """ 22 | msg = struct.pack("!I%ds" % len(text), len(text), str.encode(text)) 23 | return msg 24 | 25 | 26 | def make_field(val) -> str: 27 | """ adds the NULL string terminator """ 28 | 29 | if val is None: 30 | raise ValueError("Cannot send None to TWS") 31 | 32 | # bool type is encoded as int 33 | if type(val) is bool: 34 | val = int(val) 35 | 36 | field = str(val) + '\0' 37 | return field 38 | 39 | 40 | def make_field_handle_empty(val) -> str: 41 | 42 | if val is None: 43 | raise ValueError("Cannot send None to TWS") 44 | 45 | if UNSET_INTEGER == val or UNSET_DOUBLE == val: 46 | val = "" 47 | 48 | return make_field(val) 49 | 50 | 51 | def read_msg(buf:bytes) -> tuple: 52 | """ first the size prefix and then the corresponding msg payload """ 53 | if len(buf) < 4: 54 | return (0, "", buf) 55 | size = struct.unpack("!I", buf[0:4])[0] 56 | logger.debug("read_msg: size: %d", size) 57 | if len(buf) - 4 >= size: 58 | text = struct.unpack("!%ds" % size, buf[4:4+size])[0] 59 | return (size, text, buf[4+size:]) 60 | else: 61 | return (size, "", buf) 62 | 63 | 64 | def read_fields(buf:bytes) -> tuple: 65 | 66 | if isinstance(buf, str): 67 | buf = buf.encode() 68 | 69 | """ msg payload is made of fields terminated/separated by NULL chars """ 70 | fields = buf.split(b"\0") 71 | 72 | return tuple(fields[0:-1]) #last one is empty; this may slow dow things though, TODO 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /libs/ib_client/ibapi/commission_report.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 4 | """ 5 | 6 | from ibapi.object_implem import Object 7 | from ibapi import utils 8 | 9 | class CommissionReport(Object): 10 | 11 | def __init__(self): 12 | self.execId = "" 13 | self.commission = 0. 14 | self.currency = "" 15 | self.realizedPNL = 0. 16 | self.yield_ = 0. 17 | self.yieldRedemptionDate = 0 # YYYYMMDD format 18 | 19 | def __str__(self): 20 | return "ExecId: %s, Commission: %f, Currency: %s, RealizedPnL: %s, Yield: %s, YieldRedemptionDate: %d" % (self.execId, self.commission, 21 | self.currency, utils.floatToStr(self.realizedPNL), utils.floatToStr(self.yield_), self.yieldRedemptionDate) 22 | -------------------------------------------------------------------------------- /libs/ib_client/ibapi/connection.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 4 | """ 5 | 6 | 7 | """ 8 | Just a thin wrapper around a socket. 9 | It allows us to keep some other info along with it. 10 | """ 11 | 12 | 13 | import socket 14 | import threading 15 | import logging 16 | 17 | from ibapi.common import * # @UnusedWildImport 18 | from ibapi.errors import * # @UnusedWildImport 19 | 20 | 21 | #TODO: support SSL !! 
22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class Connection: 27 | def __init__(self, host, port): 28 | self.host = host 29 | self.port = port 30 | self.socket = None 31 | self.wrapper = None 32 | self.lock = threading.Lock() 33 | 34 | 35 | def connect(self): 36 | try: 37 | self.socket = socket.socket() 38 | #TODO: list the exceptions you want to catch 39 | except socket.error: 40 | if self.wrapper: 41 | self.wrapper.error(NO_VALID_ID, FAIL_CREATE_SOCK.code(), FAIL_CREATE_SOCK.msg()) 42 | 43 | try: 44 | self.socket.connect((self.host, self.port)) 45 | except socket.error: 46 | if self.wrapper: 47 | self.wrapper.error(NO_VALID_ID, CONNECT_FAIL.code(), CONNECT_FAIL.msg()) 48 | 49 | self.socket.settimeout(1) #non-blocking 50 | 51 | 52 | def disconnect(self): 53 | self.lock.acquire() 54 | try: 55 | if self.socket is not None: 56 | logger.debug("disconnecting") 57 | self.socket.close() 58 | self.socket = None 59 | logger.debug("disconnected") 60 | if self.wrapper: 61 | self.wrapper.connectionClosed() 62 | finally: 63 | self.lock.release() 64 | 65 | 66 | def isConnected(self): 67 | return self.socket is not None 68 | 69 | 70 | def sendMsg(self, msg): 71 | 72 | logger.debug("acquiring lock") 73 | self.lock.acquire() 74 | logger.debug("acquired lock") 75 | if not self.isConnected(): 76 | logger.debug("sendMsg attempted while not connected, releasing lock") 77 | self.lock.release() 78 | return 0 79 | try: 80 | nSent = self.socket.send(msg) 81 | except socket.error: 82 | logger.debug("exception from sendMsg %s", sys.exc_info()) 83 | raise 84 | finally: 85 | logger.debug("releasing lock") 86 | self.lock.release() 87 | logger.debug("release lock") 88 | 89 | logger.debug("sendMsg: sent: %d", nSent) 90 | 91 | return nSent 92 | 93 | 94 | def recvMsg(self): 95 | if not self.isConnected(): 96 | logger.debug("recvMsg attempted while not connected, releasing lock") 97 | return b"" 98 | try: 99 | buf = self._recvAllMsg() 100 | # receiving 0 bytes outside a timeout means the connection is either 101 | # closed or broken 102 | if len(buf) == 0: 103 | logger.debug("socket either closed or broken, disconnecting") 104 | self.disconnect() 105 | except socket.timeout: 106 | logger.debug("socket timeout from recvMsg %s", sys.exc_info()) 107 | buf = b"" 108 | else: 109 | pass 110 | 111 | return buf 112 | 113 | 114 | def _recvAllMsg(self): 115 | cont = True 116 | allbuf = b"" 117 | 118 | while cont and self.socket is not None: 119 | buf = self.socket.recv(4096) 120 | allbuf += buf 121 | logger.debug("len %d raw:%s|", len(buf), buf) 122 | 123 | if len(buf) < 4096: 124 | cont = False 125 | 126 | return allbuf 127 | 128 | -------------------------------------------------------------------------------- /libs/ib_client/ibapi/enum_implem.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 
4 | """ 5 | 6 | 7 | """ 8 | Simple enum implementation 9 | """ 10 | 11 | 12 | class Enum: 13 | def __init__(self, *args): 14 | self.idx2name = {} 15 | for (idx, name) in enumerate(args): 16 | setattr(self, name, idx) 17 | self.idx2name[idx] = name 18 | 19 | def to_str(self, idx): 20 | return self.idx2name.get(idx, "NOTFOUND") 21 | 22 | 23 | -------------------------------------------------------------------------------- /libs/ib_client/ibapi/errors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 4 | """ 5 | 6 | 7 | """ 8 | This is the interface that will need to be overloaded by the customer so 9 | that his/her code can receive info from the TWS/IBGW. 10 | """ 11 | 12 | 13 | class CodeMsgPair: 14 | def __init__(self, code, msg): 15 | self.errorCode = code 16 | self.errorMsg = msg 17 | 18 | def code(self): 19 | return self.errorCode 20 | 21 | def msg(self): 22 | return self.errorMsg 23 | 24 | 25 | ALREADY_CONNECTED = CodeMsgPair(501, "Already connected.") 26 | CONNECT_FAIL = CodeMsgPair(502, 27 | """Couldn't connect to TWS. Confirm that \"Enable ActiveX and Socket EClients\" 28 | is enabled and connection port is the same as \"Socket Port\" on the 29 | TWS \"Edit->Global Configuration...->API->Settings\" menu. Live Trading ports: 30 | TWS: 7496; IB Gateway: 4001. Simulated Trading ports for new installations 31 | of version 954.1 or newer: TWS: 7497; IB Gateway: 4002""") 32 | UPDATE_TWS = CodeMsgPair(503, "The TWS is out of date and must be upgraded.") 33 | NOT_CONNECTED = CodeMsgPair(504, "Not connected") 34 | UNKNOWN_ID = CodeMsgPair(505, "Fatal Error: Unknown message id.") 35 | UNSUPPORTED_VERSION = CodeMsgPair(506, "Unsupported version") 36 | BAD_LENGTH = CodeMsgPair(507, "Bad message length") 37 | BAD_MESSAGE = CodeMsgPair(508, "Bad message") 38 | SOCKET_EXCEPTION = CodeMsgPair(509, "Exception caught while reading socket - ") 39 | FAIL_CREATE_SOCK = CodeMsgPair(520, "Failed to create socket") 40 | SSL_FAIL = CodeMsgPair(530, "SSL specific error: ") 41 | 42 | -------------------------------------------------------------------------------- /libs/ib_client/ibapi/execution.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 4 | """ 5 | 6 | 7 | 8 | from ibapi.object_implem import Object 9 | 10 | class Execution(Object): 11 | 12 | def __init__(self): 13 | self.execId = "" 14 | self.time = "" 15 | self.acctNumber = "" 16 | self.exchange = "" 17 | self.side = "" 18 | self.shares = 0. 19 | self.price = 0. 20 | self.permId = 0 21 | self.clientId = 0 22 | self.orderId = 0 23 | self.liquidation = 0 24 | self.cumQty = 0. 25 | self.avgPrice = 0. 26 | self.orderRef = "" 27 | self.evRule = "" 28 | self.evMultiplier = 0. 
29 | self.modelCode = "" 30 | self.lastLiquidity = 0 31 | 32 | def __str__(self): 33 | return "ExecId: %s, Time: %s, Account: %s, Exchange: %s, Side: %s, Shares: %f, Price: %f, PermId: %d, " \ 34 | "ClientId: %d, OrderId: %d, Liquidation: %d, CumQty: %f, AvgPrice: %f, OrderRef: %s, EvRule: %s, " \ 35 | "EvMultiplier: %f, ModelCode: %s, LastLiquidity: %d" % (self.execId, self.time, self.acctNumber, 36 | self.exchange, self.side, self.shares, self.price, self.permId, self.clientId, self.orderId, self.liquidation, 37 | self.cumQty, self.avgPrice, self.orderRef, self.evRule, self.evMultiplier, self.modelCode, self.lastLiquidity) 38 | 39 | class ExecutionFilter(Object): 40 | 41 | # Filter fields 42 | def __init__(self): 43 | self.clientId = 0 44 | self.acctCode = "" 45 | self.time = "" 46 | self.symbol = "" 47 | self.secType = "" 48 | self.exchange = "" 49 | self.side = "" 50 | 51 | -------------------------------------------------------------------------------- /libs/ib_client/ibapi/ibapi.pyproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | Debug 5 | 2.0 6 | {aa7df1c2-6d30-4556-b6d5-a188f972bbdd} 7 | 8 | account_summary_tags.py 9 | 10 | . 11 | . 12 | {888888a0-9f3d-457c-b088-3a5042f75d52} 13 | Standard Python launcher 14 | 15 | 16 | 17 | 18 | 19 | 10.0 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /libs/ib_client/ibapi/news.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 4 | """ 5 | 6 | # TWS New Bulletins constants 7 | NEWS_MSG = 1 # standard IB news bulleting message 8 | EXCHANGE_AVAIL_MSG = 2 # control message specifing that an exchange is available for trading 9 | EXCHANGE_UNAVAIL_MSG = 3 # control message specifing that an exchange is unavailable for trading 10 | 11 | -------------------------------------------------------------------------------- /libs/ib_client/ibapi/object_implem.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 4 | """ 5 | 6 | class Object(object): 7 | 8 | def __str__(self): 9 | return "Object" 10 | 11 | def __repr__(self): 12 | return str(id(self)) + ": " + self.__str__() 13 | 14 | 15 | -------------------------------------------------------------------------------- /libs/ib_client/ibapi/order_state.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 
4 | """ 5 | 6 | from ibapi.common import UNSET_DOUBLE 7 | 8 | 9 | class OrderState: 10 | 11 | def __init__(self): 12 | self.status= "" 13 | 14 | self.initMarginBefore= "" 15 | self.maintMarginBefore= "" 16 | self.equityWithLoanBefore= "" 17 | self.initMarginChange= "" 18 | self.maintMarginChange= "" 19 | self.equityWithLoanChange= "" 20 | self.initMarginAfter= "" 21 | self.maintMarginAfter= "" 22 | self.equityWithLoanAfter= "" 23 | 24 | self.commission = UNSET_DOUBLE # type: float 25 | self.minCommission = UNSET_DOUBLE # type: float 26 | self.maxCommission = UNSET_DOUBLE # type: float 27 | self.commissionCurrency = "" 28 | self.warningText = "" 29 | self.completedTime = "" 30 | self.completedStatus = "" 31 | -------------------------------------------------------------------------------- /libs/ib_client/ibapi/reader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 4 | """ 5 | 6 | 7 | """ 8 | The EReader runs in a separate threads and is responsible for receiving the 9 | incoming messages. 10 | It will read the packets from the wire, use the low level IB messaging to 11 | remove the size prefix and put the rest in a Queue. 12 | """ 13 | 14 | import logging 15 | from threading import Thread 16 | 17 | from ibapi import comm 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class EReader(Thread): 24 | def __init__(self, conn, msg_queue): 25 | super().__init__() 26 | self.conn = conn 27 | self.msg_queue = msg_queue 28 | 29 | def run(self): 30 | try: 31 | buf = b"" 32 | while self.conn.isConnected(): 33 | 34 | data = self.conn.recvMsg() 35 | logger.debug("reader loop, recvd size %d", len(data)) 36 | buf += data 37 | 38 | while len(buf) > 0: 39 | (size, msg, buf) = comm.read_msg(buf) 40 | #logger.debug("resp %s", buf.decode('ascii')) 41 | logger.debug("size:%d msg.size:%d msg:|%s| buf:%s|", size, 42 | len(msg), buf, "|") 43 | 44 | if msg: 45 | self.msg_queue.put(msg) 46 | else: 47 | logger.debug("more incoming packet(s) are needed ") 48 | break 49 | 50 | logger.debug("EReader thread finished") 51 | except: 52 | logger.exception('unhandled exception in EReader thread') 53 | 54 | -------------------------------------------------------------------------------- /libs/ib_client/ibapi/scanner.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 
4 | """ 5 | 6 | 7 | from ibapi.object_implem import Object 8 | from ibapi.common import UNSET_INTEGER, UNSET_DOUBLE 9 | 10 | 11 | class ScanData(Object): 12 | def __init__(self, contract = None, rank = 0, distance = "", benchmark = "", projection = "", legsStr = ""): 13 | self.contract = contract 14 | self.rank = rank 15 | self.distance = distance 16 | self.benchmark = benchmark 17 | self.projection = projection 18 | self.legsStr = legsStr 19 | 20 | def __str__(self): 21 | return "Rank: %d, Symbol: %s, SecType: %s, Currency: %s, Distance: %s, Benchmark: %s, Projection: %s, Legs String: %s" % (self.rank, 22 | self.contract.symbol, self.contract.secType, self.contract.currency, self.distance, 23 | self.benchmark, self.projection, self.legsStr) 24 | 25 | NO_ROW_NUMBER_SPECIFIED = -1 26 | 27 | class ScannerSubscription(Object): 28 | 29 | def __init__(self): 30 | self.numberOfRows = NO_ROW_NUMBER_SPECIFIED 31 | self.instrument = "" 32 | self.locationCode = "" 33 | self.scanCode = "" 34 | self.abovePrice = UNSET_DOUBLE 35 | self.belowPrice = UNSET_DOUBLE 36 | self.aboveVolume = UNSET_INTEGER 37 | self.marketCapAbove = UNSET_DOUBLE 38 | self.marketCapBelow = UNSET_DOUBLE 39 | self.moodyRatingAbove = "" 40 | self.moodyRatingBelow = "" 41 | self.spRatingAbove = "" 42 | self.spRatingBelow = "" 43 | self.maturityDateAbove = "" 44 | self.maturityDateBelow = "" 45 | self.couponRateAbove = UNSET_DOUBLE 46 | self.couponRateBelow = UNSET_DOUBLE 47 | self.excludeConvertible = False 48 | self.averageOptionVolumeAbove = UNSET_INTEGER 49 | self.scannerSettingPairs = "" 50 | self.stockTypeFilter = "" 51 | 52 | 53 | def __str__(self): 54 | s = "Instrument: %s, LocationCode: %s, ScanCode: %s" % (self.instrument, self.locationCode, self.scanCode) 55 | 56 | return s 57 | 58 | -------------------------------------------------------------------------------- /libs/ib_client/ibapi/softdollartier.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 4 | """ 5 | 6 | 7 | from ibapi.object_implem import Object 8 | 9 | 10 | class SoftDollarTier(Object): 11 | def __init__(self, name = "", val = "", displayName = ""): 12 | self.name = name 13 | self.val = val 14 | self.displayName = displayName 15 | 16 | def __str__(self): 17 | return "Name: %s, Value: %s, DisplayName: %s" % (self.name, self.val, self.displayName) 18 | -------------------------------------------------------------------------------- /libs/ib_client/ibapi/tag_value.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 4 | """ 5 | 6 | """ 7 | Simple class mapping a tag to a value. Both of them are strings. 8 | They are used in a list to convey extra info with the requests. 9 | """ 10 | 11 | from ibapi.object_implem import Object 12 | 13 | 14 | class TagValue(Object): 15 | def __init__(self, tag:str=None, value:str=None): 16 | self.tag = str(tag) 17 | self.value = str(value) 18 | 19 | def __str__(self): 20 | # this is not only used for Python dump but when encoding to send 21 | # so don't change it lightly ! 
22 | return "%s=%s;" % (self.tag, self.value) 23 | 24 | 25 | -------------------------------------------------------------------------------- /libs/ib_client/ibapi/ticktype.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 4 | """ 5 | 6 | 7 | """ 8 | TickType type 9 | """ 10 | 11 | from ibapi.enum_implem import Enum 12 | 13 | 14 | # TickType 15 | TickType = int 16 | TickTypeEnum = Enum("BID_SIZE", 17 | "BID", 18 | "ASK", 19 | "ASK_SIZE", 20 | "LAST", 21 | "LAST_SIZE", 22 | "HIGH", 23 | "LOW", 24 | "VOLUME", 25 | "CLOSE", 26 | "BID_OPTION_COMPUTATION", 27 | "ASK_OPTION_COMPUTATION", 28 | "LAST_OPTION_COMPUTATION", 29 | "MODEL_OPTION", 30 | "OPEN", 31 | "LOW_13_WEEK", 32 | "HIGH_13_WEEK", 33 | "LOW_26_WEEK", 34 | "HIGH_26_WEEK", 35 | "LOW_52_WEEK", 36 | "HIGH_52_WEEK", 37 | "AVG_VOLUME", 38 | "OPEN_INTEREST", 39 | "OPTION_HISTORICAL_VOL", 40 | "OPTION_IMPLIED_VOL", 41 | "OPTION_BID_EXCH", 42 | "OPTION_ASK_EXCH", 43 | "OPTION_CALL_OPEN_INTEREST", 44 | "OPTION_PUT_OPEN_INTEREST", 45 | "OPTION_CALL_VOLUME", 46 | "OPTION_PUT_VOLUME", 47 | "INDEX_FUTURE_PREMIUM", 48 | "BID_EXCH", 49 | "ASK_EXCH", 50 | "AUCTION_VOLUME", 51 | "AUCTION_PRICE", 52 | "AUCTION_IMBALANCE", 53 | "MARK_PRICE", 54 | "BID_EFP_COMPUTATION", 55 | "ASK_EFP_COMPUTATION", 56 | "LAST_EFP_COMPUTATION", 57 | "OPEN_EFP_COMPUTATION", 58 | "HIGH_EFP_COMPUTATION", 59 | "LOW_EFP_COMPUTATION", 60 | "CLOSE_EFP_COMPUTATION", 61 | "LAST_TIMESTAMP", 62 | "SHORTABLE", 63 | "FUNDAMENTAL_RATIOS", 64 | "RT_VOLUME", 65 | "HALTED", 66 | "BID_YIELD", 67 | "ASK_YIELD", 68 | "LAST_YIELD", 69 | "CUST_OPTION_COMPUTATION", 70 | "TRADE_COUNT", 71 | "TRADE_RATE", 72 | "VOLUME_RATE", 73 | "LAST_RTH_TRADE", 74 | "RT_HISTORICAL_VOL", 75 | "IB_DIVIDENDS", 76 | "BOND_FACTOR_MULTIPLIER", 77 | "REGULATORY_IMBALANCE", 78 | "NEWS_TICK", 79 | "SHORT_TERM_VOLUME_3_MIN", 80 | "SHORT_TERM_VOLUME_5_MIN", 81 | "SHORT_TERM_VOLUME_10_MIN", 82 | "DELAYED_BID", 83 | "DELAYED_ASK", 84 | "DELAYED_LAST", 85 | "DELAYED_BID_SIZE", 86 | "DELAYED_ASK_SIZE", 87 | "DELAYED_LAST_SIZE", 88 | "DELAYED_HIGH", 89 | "DELAYED_LOW", 90 | "DELAYED_VOLUME", 91 | "DELAYED_CLOSE", 92 | "DELAYED_OPEN", 93 | "RT_TRD_VOLUME", 94 | "CREDITMAN_MARK_PRICE", 95 | "CREDITMAN_SLOW_MARK_PRICE", 96 | "DELAYED_BID_OPTION", 97 | "DELAYED_ASK_OPTION", 98 | "DELAYED_LAST_OPTION", 99 | "DELAYED_MODEL_OPTION", 100 | "LAST_EXCH", 101 | "LAST_REG_TIME", 102 | "FUTURES_OPEN_INTEREST", 103 | "AVG_OPT_VOLUME", 104 | "DELAYED_LAST_TIMESTAMP", 105 | "SHORTABLE_SHARES", 106 | "NOT_SET") 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /libs/ib_client/ibapi/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 
4 | """ 5 | 6 | 7 | """ 8 | Collection of misc tools 9 | """ 10 | 11 | 12 | import sys 13 | import logging 14 | import inspect 15 | 16 | from ibapi.common import UNSET_INTEGER, UNSET_DOUBLE, UNSET_LONG 17 | 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | # I use this just to visually emphasize it's a wrapper overriden method 23 | def iswrapper(fn): 24 | return fn 25 | 26 | 27 | class BadMessage(Exception): 28 | def __init__(self, text): 29 | self.text = text 30 | 31 | 32 | class LogFunction(object): 33 | def __init__(self, text, logLevel): 34 | self.text = text 35 | self.logLevel = logLevel 36 | 37 | def __call__(self, fn): 38 | def newFn(origSelf, *args, **kwargs): 39 | if logger.getLogger().isEnabledFor(self.logLevel): 40 | argNames = [argName for argName in inspect.getfullargspec(fn)[0] if argName != 'self'] 41 | logger.log(self.logLevel, 42 | "{} {} {} kw:{}".format(self.text, fn.__name__, 43 | [nameNarg for nameNarg in zip(argNames, args) if nameNarg[1] is not origSelf], kwargs)) 44 | fn(origSelf, *args) 45 | return newFn 46 | 47 | 48 | def current_fn_name(parent_idx = 0): 49 | #depth is 1 bc this is already a fn, so we need the caller 50 | return sys._getframe(1 + parent_idx).f_code.co_name 51 | 52 | 53 | def setattr_log(self, var_name, var_value): 54 | #import code; code.interact(local=locals()) 55 | logger.debug("%s %s %s=|%s|", self.__class__, id(self), var_name, var_value) 56 | super(self.__class__, self).__setattr__(var_name, var_value) 57 | 58 | 59 | SHOW_UNSET = True 60 | def decode(the_type, fields, show_unset = False): 61 | try: 62 | s = next(fields) 63 | except StopIteration: 64 | raise BadMessage("no more fields") 65 | 66 | logger.debug("decode %s %s", the_type, s) 67 | 68 | if the_type is str: 69 | if type(s) is str: 70 | return s 71 | elif type(s) is bytes: 72 | return s.decode(errors='backslashreplace') 73 | else: 74 | raise TypeError("unsupported incoming type " + type(s) + " for desired type 'str") 75 | 76 | orig_type = the_type 77 | if the_type is bool: 78 | the_type = int 79 | 80 | if show_unset: 81 | if s is None or len(s) == 0: 82 | if the_type is float: 83 | n = UNSET_DOUBLE 84 | elif the_type is int: 85 | n = UNSET_INTEGER 86 | else: 87 | raise TypeError("unsupported desired type for empty value" + the_type) 88 | else: 89 | n = the_type(s) 90 | else: 91 | n = the_type(s or 0) 92 | 93 | if orig_type is bool: 94 | n = False if n == 0 else True 95 | 96 | return n 97 | 98 | 99 | 100 | def ExerciseStaticMethods(klass): 101 | 102 | import types 103 | #import code; code.interact(local=dict(globals(), **locals())) 104 | for (_, var) in inspect.getmembers(klass): 105 | #print(name, var, type(var)) 106 | if type(var) == types.FunctionType: 107 | print("Exercising: %s:" % var) 108 | print(var()) 109 | print() 110 | 111 | def floatToStr(val): 112 | return str(val) if val != UNSET_DOUBLE else ""; 113 | 114 | def longToStr(val): 115 | return str(val) if val != UNSET_LONG else ""; 116 | 117 | 118 | -------------------------------------------------------------------------------- /libs/ib_client/setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 
4 | """ 5 | #from distutils.core import setup 6 | from setuptools import setup 7 | from ibapi import get_version_string 8 | 9 | import sys 10 | 11 | if sys.version_info < (3,1): 12 | sys.exit("Only Python 3.1 and greater is supported") 13 | 14 | setup( 15 | name='ibapi', 16 | version=get_version_string(), 17 | packages=['ibapi'], 18 | url='https://interactivebrokers.github.io/tws-api', 19 | license='IB API Non-Commercial License or the IB API Commercial License', 20 | author='IBG LLC', 21 | author_email='dnastase@interactivebrokers.com', 22 | description='Python IB API' 23 | ) 24 | -------------------------------------------------------------------------------- /libs/ib_client/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/libs/ib_client/tests/__init__.py -------------------------------------------------------------------------------- /libs/ib_client/tests/test_account_summary_tags.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 4 | """ 5 | 6 | import unittest 7 | from ibapi.account_summary_tags import AccountSummaryTags 8 | 9 | 10 | class AccountSummaryTagsTestCase(unittest.TestCase): 11 | def setUp(self): 12 | pass 13 | 14 | 15 | def tearDown(self): 16 | pass 17 | 18 | 19 | def test_all_tags(self): 20 | print(AccountSummaryTags.AllTags) 21 | 22 | 23 | if "__main__" == __name__: 24 | unittest.main() 25 | 26 | -------------------------------------------------------------------------------- /libs/ib_client/tests/test_comm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 
4 | """ 5 | 6 | import unittest 7 | import struct 8 | from ibapi import comm 9 | 10 | 11 | class CommTestCase(unittest.TestCase): 12 | def setUp(self): 13 | pass 14 | 15 | 16 | def tearDown(self): 17 | pass 18 | 19 | 20 | def test_make_msg(self): 21 | text = "ABCD" 22 | msg = comm.make_msg(text) 23 | 24 | size = struct.unpack("!I", msg[0:4])[0] 25 | 26 | self.assertEqual(size, len(text), "msg size not good") 27 | self.assertEqual(msg[4:].decode(), text, "msg payload not good") 28 | 29 | 30 | def test_make_field(self): 31 | text = "ABCD" 32 | field = comm.make_field(text) 33 | 34 | self.assertEqual(field[-1], "\0", "terminator not good") 35 | self.assertEqual(len(field[0:-1]), len(text), "payload size not good") 36 | self.assertEqual(field[0:-1], text, "payload not good") 37 | 38 | 39 | def test_read_msg(self): 40 | text = "ABCD" 41 | msg = comm.make_msg(text) 42 | 43 | (size, text2, rest) = comm.read_msg(msg) 44 | 45 | self.assertEqual(size, len(text), "msg size not good") 46 | self.assertEqual(text2.decode(), text, "msg payload not good") 47 | self.assertEqual(len(rest), 0, "there should be no remainder msg") 48 | 49 | 50 | def test_readFields(self): 51 | text1 = "ABCD" 52 | text2 = "123" 53 | 54 | msg = comm.make_msg(comm.make_field(text1) + comm.make_field(text2)) 55 | 56 | (size, text, rest) = comm.read_msg(msg) 57 | fields = comm.read_fields(text) 58 | 59 | self.assertEqual(len(fields), 2, "incorrect number of fields") 60 | self.assertEqual(fields[0].decode(), text1) 61 | self.assertEqual(fields[1].decode(), text2) 62 | 63 | 64 | if "__main__" == __name__: 65 | unittest.main() 66 | 67 | -------------------------------------------------------------------------------- /libs/ib_client/tests/test_enum_implem.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 4 | """ 5 | 6 | import unittest 7 | 8 | from ibapi.enum_implem import Enum 9 | 10 | 11 | class EnumTestCase(unittest.TestCase): 12 | def setUp(self): 13 | pass 14 | 15 | 16 | def tearDown(self): 17 | pass 18 | 19 | 20 | def test_enum(self): 21 | e = Enum("ZERO", "ONE", "TWO") 22 | print(e.ZERO) 23 | print(e.to_str(e.ZERO)) 24 | 25 | 26 | if "__main__" == __name__: 27 | unittest.main() 28 | 29 | -------------------------------------------------------------------------------- /libs/ib_client/tests/test_order_conditions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 
4 | """ 5 | 6 | import unittest 7 | 8 | from ibapi.order_condition import * 9 | 10 | 11 | 12 | class ConditionOrderTestCase(unittest.TestCase): 13 | conds = [ 14 | VolumeCondition(8314, "SMART", True, 1000000).And(), 15 | PercentChangeCondition(1111, "AMEX", True, 0.25).Or(), 16 | PriceCondition( 17 | PriceCondition.TriggerMethodEnum.DoubleLast, 18 | 2222, "NASDAQ", False, 4.75).And(), 19 | TimeCondition(True, "20170101 09:30:00").And(), 20 | MarginCondition(False, 200000).Or(), 21 | ExecutionCondition("STK", "SMART", "AMD") 22 | ] 23 | 24 | for cond in conds: 25 | print(cond, OrderCondition.__str__(cond)) 26 | 27 | 28 | if "__main__" == __name__: 29 | unittest.main() 30 | 31 | -------------------------------------------------------------------------------- /libs/ib_client/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 Interactive Brokers LLC. All rights reserved. This code is subject to the terms 3 | and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. 4 | """ 5 | 6 | import unittest 7 | 8 | from ibapi.enum_implem import Enum 9 | from ibapi.utils import setattr_log 10 | 11 | 12 | class UtilsTestCase(unittest.TestCase): 13 | def setUp(self): 14 | pass 15 | 16 | 17 | def tearDown(self): 18 | pass 19 | 20 | 21 | def test_enum(self): 22 | e = Enum("ZERO", "ONE", "TWO") 23 | print(e.ZERO) 24 | print(e.to_str(e.ZERO)) 25 | 26 | 27 | def test_setattr_log(self): 28 | class A: 29 | def __init__(self): 30 | self.n = 5 31 | 32 | A.__setattr__ = setattr_log 33 | a = A() 34 | print(a.n) 35 | a.n = 6 36 | print(a.n) 37 | 38 | 39 | def test_polymorphism(self): 40 | class A: 41 | def __init__(self): 42 | self.n = 5 43 | def m(self): 44 | self.n += 1 45 | class B(A): 46 | def m(self): 47 | self.n += 2 48 | 49 | o = B() 50 | #import code; code.interact(local=locals()) 51 | 52 | 53 | if "__main__" == __name__: 54 | unittest.main() 55 | 56 | -------------------------------------------------------------------------------- /libs/ib_client/tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py33, py34, py35 3 | 4 | [testenv] 5 | deps = pytest 6 | commands = py.test {posargs} 7 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "algorithmic-trader" 3 | description = "Trading bot with support for realtime trading, backtesting, custom strategies and much more" 4 | authors = ["Idan Yael"] 5 | maintainers = ["Idan Yael"] 6 | packages = [{ include = "algotrader", from = "src" }] 7 | readme = "README.md" 8 | version = "0.0.0" 9 | keywords = ["algo-trader", "trading", "backtesting", "strategy", "bot"] 10 | license = "MIT" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "Programming Language :: Python :: 3.10", 14 | "Operating System :: OS Independent", 15 | ] 16 | [tool.poetry.group.dev.dependencies] 17 | coverage = "^7.3.2" 18 | pytest = "^7.4.2" 19 | pytest-cov = "^4.1.0" 20 | ruff = "^0.0.292" 21 | black = "^23.9.1" 22 | pyright = "^1.1.331" 23 | 24 | [tool.ruff] 25 | line-length = 120 # Same as Black. 26 | select = ["E", "F"] # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default. 27 | ignore = [] # Allow autofix for all enabled rules (when `--fix`) is provided. 
28 | fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"] 29 | unfixable = [] 30 | exclude = [".bzr", ".direnv", ".eggs", ".git", ".git-rewrite", ".hg", ".mypy_cache", ".nox", ".pants.d", ".pytype", ".ruff_cache", ".svn", ".tox", ".venv", "__pypackages__", "_build", "buck-out", "build", "dist", "node_modules", "venv",] 31 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" # Allow unused variables when underscore-prefixed. 32 | target-version = "py311" 33 | mccabe = {max-complexity=10} 34 | 35 | [tool.black] 36 | line-length = 120 37 | preview = true 38 | target-version = ['py311'] 39 | 40 | 41 | [tool.poetry-dynamic-versioning] 42 | enable = true 43 | 44 | [tool.poetry.dependencies] 45 | python = ">=3.10,<3.13" 46 | newtulipy = "0.4.6" 47 | pymongo = "4.6.0" 48 | mongomock = "4.1.2" 49 | scipy = "1.11.4" 50 | yfinance = "0.2.32" 51 | typer = { version = "0.9.0", extras = ["all"] } 52 | coverage = "7.3.2" 53 | binance-connector = "1.18.0" 54 | python-dotenv = "1.0.0" 55 | ibapi = {path = "libs/ib_client"} 56 | pydantic = "^2.4.2" 57 | 58 | [build-system] 59 | requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"] 60 | build-backend = "poetry_dynamic_versioning.backend" 61 | 62 | [project.scripts] 63 | algo-trader = "algotrader.cli.main:initiate_cli" 64 | 65 | [tool.setuptools_scm] 66 | 67 | [tool.setuptools.packages.find] 68 | where = ["src"] 69 | exclude = ["tests", "design", "build", "dist", "scripts"] 70 | 71 | [project.urls] 72 | homepage = "https://github.com/idanya/algo-trader" 73 | repository = "https://github.com/idanya/algo-trader" 74 | documentation = "https://github.com/idanya/algo-trader/blob/main/README.md" 75 | bug-tracker = "https://github.com/idanya/algo-trader/issues" 76 | 77 | [tool.pytest.ini_options] 78 | pythonpath = ["src"] 79 | 80 | [tool.pyright] 81 | include = ["src/"] 82 | exclude = ["**/node_modules", 83 | "src/algotrader/providers/ib/", 84 | "**/__pycache__", 85 | "libs/", 86 | ] 87 | defineConstant = { DEBUG = true } 88 | venv = "env311" 89 | 90 | reportMissingImports = true 91 | reportMissingTypeStubs = false 92 | 93 | pythonVersion = "3.11" 94 | pythonPlatform = "Linux" 95 | 96 | executionEnvironments = [ 97 | { root = "src" } 98 | ] -------------------------------------------------------------------------------- /src/algotrader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/src/algotrader/__init__.py -------------------------------------------------------------------------------- /src/algotrader/assets/assets_provider.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | from typing import List 3 | 4 | SP500_SYMBOLS = "sp500.symbols" 5 | CRYPTO_SYMBOLS = "crypto.symbols" 6 | 7 | 8 | class AssetsProvider: 9 | @staticmethod 10 | def get_sp500_symbols() -> List[str]: 11 | return AssetsProvider._get_file_lines(SP500_SYMBOLS) 12 | 13 | @staticmethod 14 | def get_crypto_symbols() -> List[str]: 15 | return AssetsProvider._get_file_lines(CRYPTO_SYMBOLS) 16 | 17 | @staticmethod 18 | def _get_file_lines(filename: str) -> List[str]: 19 | symbols_file = path.join(path.dirname(__file__), filename) 20 | with 
open(symbols_file, "r") as file: 21 | return [line.rstrip("\n") for line in file] 22 | -------------------------------------------------------------------------------- /src/algotrader/assets/correlation_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "groups": [ 3 | ["AAPL", "MSFT"], 4 | ["CHKP", "FB"] 5 | ] 6 | } -------------------------------------------------------------------------------- /src/algotrader/assets/crypto.symbols: -------------------------------------------------------------------------------- 1 | BTCUSDT -------------------------------------------------------------------------------- /src/algotrader/calc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/src/algotrader/calc/__init__.py -------------------------------------------------------------------------------- /src/algotrader/calc/calculations.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class TechnicalCalculation(Enum): 5 | TYPICAL = "typical" 6 | SMA = "sma" 7 | CCI = "cci" 8 | MACD = "macd" 9 | RSI = "rsi" 10 | ADXR = "adxr" 11 | STDDEV = "stddev" 12 | EMA = "ema" 13 | MOM = "mom" 14 | NATR = "natr" 15 | MEANDEV = "meandev" 16 | OBV = "obv" 17 | VAR = "var" 18 | VOSC = "vosc" 19 | STOCH = "stoch" 20 | FISHER = "fisher" 21 | AROONOSC = "aroonosc" 22 | BBANDS = "bbands" 23 | -------------------------------------------------------------------------------- /src/algotrader/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/src/algotrader/cli/__init__.py -------------------------------------------------------------------------------- /src/algotrader/cli/helpers.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Set, List 3 | 4 | import algotrader.pipeline 5 | import algotrader.pipeline.processors 6 | import algotrader.pipeline.sources 7 | import algotrader.pipeline.strategies 8 | 9 | 10 | def _get_all_of_class(base_class): 11 | results: Set[str] = set() 12 | 13 | def list_module_childs(m): 14 | for name, obj in inspect.getmembers(m): 15 | if inspect.ismodule(obj) and obj.__name__.startswith(m.__name__): 16 | list_module_childs(obj) 17 | elif inspect.isclass(obj) and issubclass(obj, base_class) and obj.__name__ != base_class.__name__: 18 | results.add(obj) 19 | 20 | list_module_childs(algotrader.pipeline) 21 | return results 22 | 23 | 24 | def _get_all_of_class_names(base_class) -> List[str]: 25 | return [p.__name__ for p in _get_all_of_class(base_class)] 26 | 27 | 28 | def _get_single_by_name(base_class, name: str): 29 | return next(filter(lambda p: p.__name__ == name, _get_all_of_class(base_class))) 30 | 31 | 32 | def _describe_object(obj): 33 | if obj.__doc__: 34 | print(f"Description: {obj.__doc__}") 35 | if obj.__init__.__doc__: 36 | print(f"Parameters: {obj.__init__.__doc__}") 37 | -------------------------------------------------------------------------------- /src/algotrader/cli/main.py: -------------------------------------------------------------------------------- 1 | import typer 2 | 3 | from algotrader.cli import processors, strategies, sources, pipeline 4 | 5 | app = typer.Typer(no_args_is_help=True) 6 | 
app.add_typer(processors.app, name="processor", short_help="Processors related commands") 7 | app.add_typer(strategies.app, name="strategy", short_help="Strategies related commands") 8 | app.add_typer(sources.app, name="source", short_help="Sources related commands") 9 | app.add_typer(pipeline.app, name="pipeline", short_help="Pipelines related commands") 10 | 11 | 12 | def initiate_cli(): 13 | app() 14 | -------------------------------------------------------------------------------- /src/algotrader/cli/pipeline.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import typer 4 | 5 | from algotrader.pipeline.pipeline import Pipeline 6 | from algotrader.pipeline.runner import PipelineRunner 7 | from algotrader.serialization.store import DeserializationService 8 | 9 | app = typer.Typer(no_args_is_help=True) 10 | 11 | 12 | def load_pipeline_spec(file_path: str) -> Pipeline: 13 | with open(file_path, "r") as input_file: 14 | return DeserializationService.deserialize(json.loads(input_file.read())) 15 | 16 | 17 | @app.command() 18 | def run(path: str): 19 | """ 20 | Create and run a JSON serialized pipeline from file 21 | """ 22 | pipeline = load_pipeline_spec(path) 23 | PipelineRunner(pipeline).run() 24 | -------------------------------------------------------------------------------- /src/algotrader/cli/processors.py: -------------------------------------------------------------------------------- 1 | import typer 2 | 3 | from algotrader.cli.helpers import _describe_object, _get_all_of_class_names, _get_single_by_name 4 | from algotrader.pipeline.processor import Processor 5 | 6 | app = typer.Typer(no_args_is_help=True) 7 | 8 | 9 | @app.command() 10 | def list(): 11 | print("\n".join(_get_all_of_class_names(Processor))) 12 | 13 | 14 | @app.command() 15 | def describe(name: str): 16 | _describe_object(_get_single_by_name(Processor, name)) 17 | -------------------------------------------------------------------------------- /src/algotrader/cli/sources.py: -------------------------------------------------------------------------------- 1 | import typer 2 | 3 | from algotrader.cli.helpers import _describe_object, _get_all_of_class_names, _get_single_by_name 4 | from algotrader.pipeline.source import Source 5 | 6 | app = typer.Typer(no_args_is_help=True) 7 | 8 | 9 | @app.command() 10 | def list(): 11 | print("\n".join(_get_all_of_class_names(Source))) 12 | 13 | 14 | @app.command() 15 | def describe(name: str): 16 | _describe_object(_get_single_by_name(Source, name)) 17 | -------------------------------------------------------------------------------- /src/algotrader/cli/strategies.py: -------------------------------------------------------------------------------- 1 | import typer 2 | 3 | from algotrader.cli.helpers import _describe_object, _get_all_of_class_names, _get_single_by_name 4 | from algotrader.entities.strategy import Strategy 5 | 6 | app = typer.Typer(no_args_is_help=True) 7 | 8 | 9 | @app.command() 10 | def list(): 11 | print("\n".join(_get_all_of_class_names(Strategy))) 12 | 13 | 14 | @app.command() 15 | def describe(name: str): 16 | strategy = _get_single_by_name(Strategy, name) 17 | _describe_object(strategy) 18 | -------------------------------------------------------------------------------- /src/algotrader/entities/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/src/algotrader/entities/__init__.py -------------------------------------------------------------------------------- /src/algotrader/entities/attachments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/src/algotrader/entities/attachments/__init__.py -------------------------------------------------------------------------------- /src/algotrader/entities/attachments/assets_correlation.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | from algotrader.entities.attachments.technicals import IndicatorValue 4 | from algotrader.entities.generic_candle_attachment import GenericCandleAttachment 5 | 6 | 7 | class AssetCorrelation(GenericCandleAttachment[IndicatorValue]): 8 | type: Literal["AssetCorrelation"] = "AssetCorrelation" 9 | -------------------------------------------------------------------------------- /src/algotrader/entities/attachments/nothing.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Literal 4 | 5 | from pydantic import Field 6 | 7 | from algotrader.entities.base_dto import BaseEntity 8 | 9 | 10 | class NothingClass(BaseEntity): 11 | type: Literal["NothingClass"] = "NothingClass" 12 | nothing: str = Field(default="nothing-at-all") 13 | -------------------------------------------------------------------------------- /src/algotrader/entities/attachments/returns.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Literal 4 | 5 | from algotrader.entities.generic_candle_attachment import GenericCandleAttachment 6 | 7 | 8 | class Returns(GenericCandleAttachment[float]): 9 | type: Literal["Returns"] = "Returns" 10 | -------------------------------------------------------------------------------- /src/algotrader/entities/attachments/technicals.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Literal, Union, List 4 | 5 | from algotrader.entities.generic_candle_attachment import GenericCandleAttachment 6 | 7 | IndicatorValue = Union[List[float], float] 8 | 9 | 10 | class Indicators(GenericCandleAttachment[IndicatorValue]): 11 | type: Literal["Indicators"] = "Indicators" 12 | -------------------------------------------------------------------------------- /src/algotrader/entities/attachments/technicals_buckets_matcher.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List, Literal 2 | 3 | from algotrader.entities.bucket import Bucket 4 | from algotrader.entities.generic_candle_attachment import GenericCandleAttachment 5 | 6 | 7 | class IndicatorsMatchedBuckets(GenericCandleAttachment[Union[List[Bucket], Bucket]]): 8 | type: Literal["IndicatorsMatchedBuckets"] = "IndicatorsMatchedBuckets" 9 | -------------------------------------------------------------------------------- /src/algotrader/entities/attachments/technicals_normalizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Literal 4 | 5 | from 
algotrader.entities.attachments.technicals import IndicatorValue 6 | from algotrader.entities.generic_candle_attachment import GenericCandleAttachment 7 | 8 | 9 | class NormalizedIndicators(GenericCandleAttachment[IndicatorValue]): 10 | type: Literal["NormalizedIndicators"] = "NormalizedIndicators" 11 | -------------------------------------------------------------------------------- /src/algotrader/entities/base_dto.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import ClassVar 3 | 4 | from pydantic import BaseModel, Field 5 | 6 | 7 | class BaseEntity(BaseModel): 8 | _types: ClassVar[dict[str, type]] = {} 9 | 10 | created_at: datetime = Field(default_factory=datetime.utcnow) 11 | updated_at: datetime = Field(default_factory=datetime.utcnow) 12 | -------------------------------------------------------------------------------- /src/algotrader/entities/bucket.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | from typing import List, Union, Optional 5 | 6 | from pydantic import Field 7 | 8 | from algotrader.entities.base_dto import BaseEntity 9 | 10 | 11 | class Bucket(BaseEntity): 12 | ident: float 13 | start: Optional[float] = Field(default=-math.inf) 14 | end: Optional[float] = Field(default=math.inf) 15 | 16 | @property 17 | def get_start(self): 18 | return self.start or -math.inf 19 | 20 | @property 21 | def get_end(self): 22 | return self.end or math.inf 23 | 24 | 25 | BucketList = List[Bucket] 26 | CompoundBucketList = Union[List[BucketList], BucketList] 27 | -------------------------------------------------------------------------------- /src/algotrader/entities/bucketscontainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Dict, Optional, ItemsView 4 | 5 | from pydantic import Field 6 | 7 | from algotrader.entities.base_dto import BaseEntity 8 | from algotrader.entities.bucket import Bucket, CompoundBucketList 9 | from algotrader.serialization.store import DeserializationService 10 | 11 | 12 | class BucketsContainer(BaseEntity): 13 | bins: dict[str, CompoundBucketList] = Field(default_factory=dict) 14 | 15 | def items(self) -> ItemsView[str, CompoundBucketList]: 16 | return self.bins.items() 17 | 18 | def add(self, indicator: str, value: CompoundBucketList): 19 | self.bins[indicator] = value 20 | 21 | def get(self, indicator: str) -> Optional[CompoundBucketList]: 22 | return self.bins[indicator] if indicator in self.bins else None 23 | 24 | def serialize(self) -> Dict: 25 | data = super().serialize() 26 | for key, value in self.bins.items(): 27 | if isinstance(value[0], list): 28 | data[key] = [] 29 | for arr in value: 30 | data[key].append([x.serialize() for x in arr]) 31 | elif isinstance(value[0], Bucket): 32 | data[key] = [x.serialize() for x in value] 33 | 34 | return data 35 | 36 | @classmethod 37 | def deserialize(cls, data: Dict) -> BucketsContainer: 38 | bins = BucketsContainer() 39 | for key, value in data.items(): 40 | if key == "__class__": 41 | continue 42 | 43 | if isinstance(value[0], list): 44 | lists = [] 45 | for lst in value: 46 | lists.append([DeserializationService.deserialize(x) for x in lst]) 47 | 48 | bins.add(key, lists) 49 | else: 50 | bins.add(key, [DeserializationService.deserialize(x) for x in value]) 51 | 52 | return bins 53 | 54 | 55 | BucketsContainer() 56 | 
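A short illustrative sketch (not part of the repository sources) of how the Bucket and BucketsContainer entities above fit together; the "rsi14" key and the 30/70 boundaries are made-up example values.

```python
# Illustrative usage sketch for Bucket / BucketsContainer (example values only).
from algotrader.entities.bucket import Bucket
from algotrader.entities.bucketscontainer import BucketsContainer

# Three buckets splitting a hypothetical "rsi14" indicator at 30 and 70.
buckets = [
    Bucket(ident=0, end=30.0),
    Bucket(ident=1, start=30.0, end=70.0),
    Bucket(ident=2, start=70.0),
]

container = BucketsContainer()
container.add("rsi14", buckets)

assert container.get("rsi14")[1].get_start == 30.0
assert container.get("unknown-indicator") is None
```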
-------------------------------------------------------------------------------- /src/algotrader/entities/candle.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from datetime import datetime 4 | from typing import Optional, Annotated, Literal 5 | 6 | from pydantic import Field 7 | 8 | from algotrader.entities.base_dto import BaseEntity 9 | from algotrader.entities.candle_attachments import CandleAttachment 10 | from algotrader.entities.timespan import TimeSpan 11 | 12 | 13 | def timestamp_to_str(d: datetime) -> str: 14 | return d.strftime("%Y%m%d %H:%M:%S.%f") 15 | 16 | 17 | def str_to_timestamp(d: str) -> datetime: 18 | return datetime.strptime(d, "%Y%m%d %H:%M:%S.%f") 19 | 20 | 21 | class Candle(BaseEntity): 22 | type: Literal["Candle"] = "Candle" 23 | symbol: str 24 | timestamp: datetime 25 | time_span: TimeSpan 26 | 27 | open: float 28 | close: float 29 | high: float 30 | low: float 31 | volume: float 32 | 33 | attachments: Optional[dict[str, Annotated[CandleAttachment, Field(discriminator="type")]]] = None 34 | 35 | def add_attachment(self, key: str, entity: BaseEntity): 36 | if not self.attachments: 37 | self.attachments = {} 38 | 39 | self.attachments[key] = entity 40 | 41 | def get_attachment(self, key: str) -> Optional[CandleAttachment]: 42 | if self.attachments: 43 | return self.attachments.get(key) 44 | -------------------------------------------------------------------------------- /src/algotrader/entities/candle_attachments.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Dict, Optional, Union 4 | 5 | from algotrader.entities.attachments.nothing import NothingClass 6 | from algotrader.entities.serializable import Serializable, Deserializable 7 | from algotrader.entities.attachments.technicals_buckets_matcher import IndicatorsMatchedBuckets 8 | from algotrader.entities.attachments.assets_correlation import AssetCorrelation 9 | from algotrader.entities.attachments.technicals import Indicators 10 | from algotrader.entities.attachments.technicals_normalizer import NormalizedIndicators 11 | from algotrader.serialization.store import DeserializationService 12 | 13 | CandleAttachment = Union[NothingClass, NormalizedIndicators, Indicators, AssetCorrelation, IndicatorsMatchedBuckets] 14 | 15 | 16 | class CandleAttachments(Serializable, Deserializable): 17 | def __init__(self) -> None: 18 | super().__init__() 19 | self.data: Dict[str, Serializable] = {} 20 | 21 | @classmethod 22 | def deserialize(cls, data: Dict): 23 | obj = CandleAttachments() 24 | for k, v in data.items(): 25 | if k != "__class__" and isinstance(v, dict) and "__class__" in v: 26 | obj.add_attachement(k, DeserializationService.deserialize(v)) 27 | 28 | return obj 29 | 30 | def add_attachement(self, key: str, data: Serializable): 31 | self.data[key] = data 32 | 33 | def get_attachment(self, key: str) -> Optional[Serializable]: 34 | return self.data.get(key, None) 35 | 36 | def serialize(self) -> Dict: 37 | obj = super().serialize() 38 | 39 | for k, v in self.data.items(): 40 | if v: 41 | obj[k] = v.serialize() 42 | 43 | return obj 44 | -------------------------------------------------------------------------------- /src/algotrader/entities/event.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class Event(Enum): 5 | TimeSpanChange = "TimeSpanChange" 6 | 
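A short illustrative sketch (not part of the repository sources) showing a Candle being constructed and given an Indicators attachment, the way the technicals processor does; the symbol, prices and indicator value are made up.

```python
# Illustrative usage sketch for Candle and its attachments (example values only).
from datetime import datetime

from algotrader.entities.attachments.technicals import Indicators
from algotrader.entities.candle import Candle
from algotrader.entities.timespan import TimeSpan

candle = Candle(
    symbol="BTCUSDT",
    timestamp=datetime(2022, 1, 1),
    time_span=TimeSpan.Day,
    open=47_000.0,
    close=47_500.0,
    high=47_900.0,
    low=46_800.0,
    volume=1_000.0,
)

indicators = Indicators()
indicators.set("sma5", 47_300.0)
candle.add_attachment("indicators", indicators)  # same key the technicals processor uses

assert candle.get_attachment("indicators").has("sma5")
```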
-------------------------------------------------------------------------------- /src/algotrader/entities/generic_candle_attachment.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Dict, TypeVar, Generic, Optional, ItemsView 4 | 5 | from pydantic import Field 6 | 7 | from algotrader.entities.base_dto import BaseEntity 8 | 9 | T = TypeVar("T") 10 | 11 | 12 | class GenericCandleAttachment(Generic[T], BaseEntity): 13 | data: Dict[str, T] = Field(default_factory=dict) 14 | 15 | def __getitem__(self, key): 16 | return self.data[key] 17 | 18 | def set(self, key: str, value: T): 19 | self.data[key] = value 20 | 21 | def get(self, key: str) -> Optional[T]: 22 | return self.data[key] 23 | 24 | def items(self) -> ItemsView[str, T]: 25 | data = {} 26 | for k, v in self.data.items(): 27 | if k == "__class__": 28 | continue 29 | data.update({k: v}) 30 | 31 | return data.items() 32 | 33 | def has(self, key: str): 34 | return key in self.data and self.data[key] is not None 35 | -------------------------------------------------------------------------------- /src/algotrader/entities/order_direction.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class OrderDirection(Enum): 5 | Buy = 1 6 | Sell = 2 7 | -------------------------------------------------------------------------------- /src/algotrader/entities/serializable.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | 4 | class Serializable: 5 | def serialize(self) -> Dict: 6 | module = self.__class__.__module__ 7 | name = self.__class__.__name__ 8 | return {"__class__": f"{module}:{name}"} 9 | 10 | 11 | class Deserializable: 12 | @classmethod 13 | def deserialize(cls, data: Dict): 14 | return cls() 15 | -------------------------------------------------------------------------------- /src/algotrader/entities/strategy.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import abstractmethod 4 | from typing import List 5 | 6 | from algotrader.entities.candle import Candle 7 | from algotrader.entities.serializable import Deserializable, Serializable 8 | from algotrader.entities.strategy_signal import StrategySignal 9 | from algotrader.pipeline.shared_context import SharedContext 10 | 11 | 12 | class Strategy(Serializable, Deserializable): 13 | def __init__(self): 14 | pass 15 | 16 | @abstractmethod 17 | def process(self, context: SharedContext, candle: Candle) -> List[StrategySignal]: 18 | pass 19 | -------------------------------------------------------------------------------- /src/algotrader/entities/strategy_signal.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class SignalDirection(Enum): 5 | Long = 1 6 | Short = 2 7 | 8 | 9 | class StrategySignal: 10 | def __init__(self, symbol: str, direction: SignalDirection) -> None: 11 | self.symbol = symbol 12 | self.direction = direction 13 | -------------------------------------------------------------------------------- /src/algotrader/entities/timespan.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum, unique 2 | 3 | 4 | @unique 5 | class TimeSpan(IntEnum): 6 | Second = 1 7 | Minute = 60 * Second 8 | Hour = 60 * Minute 9 | Day = 24 * Hour 10 | 
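A short illustrative sketch (not part of the repository sources): a minimal Strategy implementation against the interfaces above. GreenCandleStrategy is a made-up example, not one of the bundled strategies.

```python
# Illustrative sketch of a custom Strategy (hypothetical, for demonstration only).
from typing import List

from algotrader.entities.candle import Candle
from algotrader.entities.strategy import Strategy
from algotrader.entities.strategy_signal import SignalDirection, StrategySignal
from algotrader.pipeline.shared_context import SharedContext


class GreenCandleStrategy(Strategy):
    """Signals Long whenever a candle closes above its open."""

    def process(self, context: SharedContext, candle: Candle) -> List[StrategySignal]:
        if candle.close > candle.open:
            return [StrategySignal(candle.symbol, SignalDirection.Long)]
        return []
```

A strategy like this is plugged into a StrategyProcessor together with a SignalsExecutor implementation (see pipeline/processors/strategy.py below).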
-------------------------------------------------------------------------------- /src/algotrader/examples/mongo-history-by-buckets.md: -------------------------------------------------------------------------------- 1 | ```json 2 | [ 3 | { 4 | '$match': { 5 | "attachments.indicators_matched_buckets.sma5.ident": { 6 | '$exists': true 7 | }, 8 | "attachments.indicators_matched_buckets.sma20.ident": { 9 | '$exists': true 10 | }, 11 | "attachments.returns.ctc1": { 12 | '$exists': true 13 | } 14 | } 15 | }, 16 | { 17 | "$group": { 18 | "_id": { 19 | "sma5": "$attachments.indicators_matched_buckets.sma5.ident", 20 | "sma20": "$attachments.indicators_matched_buckets.sma20.ident" 21 | }, 22 | "avg": { 23 | "$avg": "$attachments.returns.ctc1" 24 | }, 25 | "count": { 26 | "$sum": 1 27 | } 28 | } 29 | }, 30 | { 31 | '$match': { 32 | "count": { 33 | '$gte': 1000 34 | }, 35 | "avg": { 36 | '$gte': 0 37 | } 38 | } 39 | } 40 | ] 41 | ``` 42 | -------------------------------------------------------------------------------- /src/algotrader/examples/pipeline-templates/bins.json: -------------------------------------------------------------------------------- 1 | {"__class__": "entities.bucketscontainer:BucketsContainer"} -------------------------------------------------------------------------------- /src/algotrader/examples/pipeline-templates/build_daily_binance_loader.json: -------------------------------------------------------------------------------- 1 | { 2 | "__class__": "algotrader.pipeline.pipeline:Pipeline", 3 | "source": { 4 | "__class__": "algotrader.pipeline.sources.binance_history:BinanceHistorySource", 5 | "binanceProvider": { 6 | "apiKey": "", 7 | "apiSecret": "", 8 | "enableWebsocket": false 9 | }, 10 | "symbols": [ 11 | "BTCUSDT" 12 | ], 13 | "timeSpan": 86400, 14 | "startTime": 1609452000.0, 15 | "endTime": 1640988000.0 16 | }, 17 | "processor": { 18 | "__class__": "algotrader.pipeline.processors.technicals:TechnicalsProcessor", 19 | "next_processor": { 20 | "__class__": "algotrader.pipeline.processors.candle_cache:CandleCache", 21 | "next_processor": { 22 | "__class__": "algotrader.pipeline.processors.storage_provider_sink:StorageSinkProcessor", 23 | "storage_provider": { 24 | "__class__": "algotrader.storage.mongodb_storage:MongoDBStorage", 25 | "host": "localhost", 26 | "port": 27017, 27 | "database": "algo-trader", 28 | "username": "root", 29 | "password": "root" 30 | } 31 | } 32 | }, 33 | "config": { 34 | "technicals": [ 35 | { 36 | "name": "sma5", 37 | "type": "sma", 38 | "params": [ 39 | 5 40 | ] 41 | }, 42 | { 43 | "name": "sma20", 44 | "type": "sma", 45 | "params": [ 46 | 20 47 | ] 48 | }, 49 | { 50 | "name": "cci7", 51 | "type": "cci", 52 | "params": [ 53 | 7 54 | ] 55 | }, 56 | { 57 | "name": "cci14", 58 | "type": "cci", 59 | "params": [ 60 | 14 61 | ] 62 | }, 63 | { 64 | "name": "macd", 65 | "type": "macd", 66 | "params": [ 67 | 2, 68 | 5, 69 | 9 70 | ] 71 | }, 72 | { 73 | "name": "rsi7", 74 | "type": "cci", 75 | "params": [ 76 | 7 77 | ] 78 | }, 79 | { 80 | "name": "rsi14", 81 | "type": "cci", 82 | "params": [ 83 | 14 84 | ] 85 | }, 86 | { 87 | "name": "adxr5", 88 | "type": "adxr", 89 | "params": [ 90 | 5 91 | ] 92 | }, 93 | { 94 | "name": "stddev5", 95 | "type": "stddev", 96 | "params": [ 97 | 5 98 | ] 99 | }, 100 | { 101 | "name": "ema5", 102 | "type": "ema", 103 | "params": [ 104 | 5 105 | ] 106 | }, 107 | { 108 | "name": "ema20", 109 | "type": "ema", 110 | "params": [ 111 | 20 112 | ] 113 | }, 114 | { 115 | "name": "mom5", 116 | "type": "mom", 117 | "params": [ 118 | 5 
119 | ] 120 | }, 121 | { 122 | "name": "natr5", 123 | "type": "natr", 124 | "params": [ 125 | 5 126 | ] 127 | }, 128 | { 129 | "name": "meandev5", 130 | "type": "meandev", 131 | "params": [ 132 | 5 133 | ] 134 | }, 135 | { 136 | "name": "obv", 137 | "type": "obv", 138 | "params": [] 139 | }, 140 | { 141 | "name": "var5", 142 | "type": "var", 143 | "params": [ 144 | 5 145 | ] 146 | }, 147 | { 148 | "name": "vosc", 149 | "type": "vosc", 150 | "params": [ 151 | 2, 152 | 5 153 | ] 154 | } 155 | ] 156 | } 157 | }, 158 | "terminator": null 159 | } -------------------------------------------------------------------------------- /src/algotrader/examples/pipeline-templates/build_realtime_binance.json: -------------------------------------------------------------------------------- 1 | { 2 | "__class__": "algotrader.pipeline.pipeline:Pipeline", 3 | "source": { 4 | "__class__": "algotrader.pipeline.sources.binance_realtime:BinanceRealtimeSource", 5 | "binanceProvider": { 6 | "apiKey": "", 7 | "apiSecret": "", 8 | "enableWebsocket": true 9 | }, 10 | "symbols": [ 11 | "BTCUSDT" 12 | ], 13 | "timeSpan": 1 14 | }, 15 | "processor": { 16 | "__class__": "algotrader.pipeline.processors.technicals:TechnicalsProcessor", 17 | "next_processor": { 18 | "__class__": "algotrader.pipeline.processors.candle_cache:CandleCache", 19 | "next_processor": { 20 | "__class__": "algotrader.pipeline.processors.storage_provider_sink:StorageSinkProcessor", 21 | "storage_provider": { 22 | "__class__": "algotrader.storage.mongodb_storage:MongoDBStorage", 23 | "host": "localhost", 24 | "port": 27017, 25 | "database": "algo-trader", 26 | "username": "root", 27 | "password": "root" 28 | } 29 | } 30 | }, 31 | "config": { 32 | "technicals": [ 33 | { 34 | "name": "sma5", 35 | "type": "sma", 36 | "params": [ 37 | 5 38 | ] 39 | }, 40 | { 41 | "name": "sma20", 42 | "type": "sma", 43 | "params": [ 44 | 20 45 | ] 46 | }, 47 | { 48 | "name": "cci7", 49 | "type": "cci", 50 | "params": [ 51 | 7 52 | ] 53 | }, 54 | { 55 | "name": "cci14", 56 | "type": "cci", 57 | "params": [ 58 | 14 59 | ] 60 | }, 61 | { 62 | "name": "macd", 63 | "type": "macd", 64 | "params": [ 65 | 2, 66 | 5, 67 | 9 68 | ] 69 | }, 70 | { 71 | "name": "rsi7", 72 | "type": "cci", 73 | "params": [ 74 | 7 75 | ] 76 | }, 77 | { 78 | "name": "rsi14", 79 | "type": "cci", 80 | "params": [ 81 | 14 82 | ] 83 | }, 84 | { 85 | "name": "adxr5", 86 | "type": "adxr", 87 | "params": [ 88 | 5 89 | ] 90 | }, 91 | { 92 | "name": "stddev5", 93 | "type": "stddev", 94 | "params": [ 95 | 5 96 | ] 97 | }, 98 | { 99 | "name": "ema5", 100 | "type": "ema", 101 | "params": [ 102 | 5 103 | ] 104 | }, 105 | { 106 | "name": "ema20", 107 | "type": "ema", 108 | "params": [ 109 | 20 110 | ] 111 | }, 112 | { 113 | "name": "mom5", 114 | "type": "mom", 115 | "params": [ 116 | 5 117 | ] 118 | }, 119 | { 120 | "name": "natr5", 121 | "type": "natr", 122 | "params": [ 123 | 5 124 | ] 125 | }, 126 | { 127 | "name": "meandev5", 128 | "type": "meandev", 129 | "params": [ 130 | 5 131 | ] 132 | }, 133 | { 134 | "name": "obv", 135 | "type": "obv", 136 | "params": [] 137 | }, 138 | { 139 | "name": "var5", 140 | "type": "var", 141 | "params": [ 142 | 5 143 | ] 144 | }, 145 | { 146 | "name": "vosc", 147 | "type": "vosc", 148 | "params": [ 149 | 2, 150 | 5 151 | ] 152 | } 153 | ] 154 | } 155 | }, 156 | "terminator": null 157 | } -------------------------------------------------------------------------------- /src/algotrader/examples/pipeline-templates/correlation.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "groups": [ 3 | [ 4 | "AAPL", 5 | "MSFT", 6 | "TSM", 7 | "NVDA", 8 | "ASML", 9 | "ASMLF", 10 | "AVGO", 11 | "ADBE", 12 | "CSCO", 13 | "ACN", 14 | "ORCL", 15 | "CRM", 16 | "QCOM", 17 | "INTC", 18 | "INTU", 19 | "TXN", 20 | "AMD", 21 | "SAP", 22 | "SAPGF", 23 | "SONY", 24 | "SNEJF", 25 | "AMAT", 26 | "KYCCF", 27 | "IBM", 28 | "NOW", 29 | "SHOP" 30 | ] 31 | ] 32 | } -------------------------------------------------------------------------------- /src/algotrader/examples/pipeline-templates/loader_simple_daily_loader.json: -------------------------------------------------------------------------------- 1 | { 2 | "__class__": "pipeline.pipeline:Pipeline", 3 | "source": { 4 | "__class__": "pipeline.sources.ib_history:IBHistorySource" 5 | }, 6 | "processor": { 7 | "__class__": "pipeline.processors.technicals:TechnicalsProcessor", 8 | "next_processor": { 9 | "__class__": "pipeline.processors.candle_cache:CandleCache", 10 | "next_processor": { 11 | "__class__": "pipeline.processors.mongodb_sink:MongoDBSinkProcessor" 12 | } 13 | } 14 | }, 15 | "terminator": null 16 | } -------------------------------------------------------------------------------- /src/algotrader/logger/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pathlib 4 | import time 5 | 6 | 7 | def setup_default_logger(): 8 | log_dir = pathlib.Path(__file__).parent.parent.joinpath("logs").resolve() 9 | level = logging.DEBUG if os.environ.get("DEBUG") else logging.INFO 10 | 11 | if not pathlib.Path.exists(log_dir): 12 | pathlib.Path.mkdir(log_dir) 13 | 14 | recfmt = "(%(threadName)s) %(asctime)s.%(msecs)03d %(levelname)s %(filename)s:%(lineno)d %(message)s" 15 | 16 | timefmt = "%y%m%d_%H:%M:%S" 17 | 18 | logging.basicConfig( 19 | filename=f'{log_dir}/{time.strftime("algo-trader.%y%m%d_%H%M%S.log")}', 20 | filemode="w", 21 | level=level, 22 | format=recfmt, 23 | datefmt=timefmt, 24 | ) 25 | logger = logging.getLogger() 26 | console = logging.StreamHandler() 27 | console.setLevel(level) 28 | logger.addHandler(console) 29 | -------------------------------------------------------------------------------- /src/algotrader/main.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pathlib 3 | 4 | from algotrader.cli.main import initiate_cli 5 | 6 | # from logger import setup_default_logger 7 | from algotrader.pipeline.builders.backtest import BacktestPipelines 8 | from algotrader.pipeline.builders.loaders import LoadersPipelines 9 | from algotrader.pipeline.pipeline import Pipeline 10 | 11 | BIN_COUNT = 10 12 | 13 | EXAMPLE_TEMPLATES_DIR = pathlib.Path(__file__).parent.joinpath("examples/pipeline-templates").resolve() 14 | """ 15 | Main entry point, you can use the LoadersPipelines or the BacktestPipelines in order to run an example pipeline. 16 | This should eventually be the CLI entrypoint. For now, it's for running examples. 
17 | """ 18 | 19 | 20 | def save_pipeline_spec(filename: str, pipeline: Pipeline): 21 | if not pathlib.Path.exists(EXAMPLE_TEMPLATES_DIR): 22 | pathlib.Path.mkdir(EXAMPLE_TEMPLATES_DIR) 23 | 24 | with open(pathlib.Path(EXAMPLE_TEMPLATES_DIR).joinpath(filename), "w") as output_file: 25 | output_file.write(json.dumps(pipeline.serialize(), indent=2, default=str)) 26 | 27 | 28 | def generate_example_templates(): 29 | save_pipeline_spec("build_realtime_binance.json", LoadersPipelines.build_realtime_binance()) 30 | save_pipeline_spec("build_daily_binance_loader.json", LoadersPipelines.build_daily_binance_loader()) 31 | save_pipeline_spec("build_daily_yahoo_loader.json", LoadersPipelines.build_daily_yahoo_loader()) 32 | save_pipeline_spec("backtest_mongo_source_rsi_strategy.json", BacktestPipelines.build_mongodb_backtester()) 33 | save_pipeline_spec( 34 | "backtest_history_buckets_backtester.json", 35 | BacktestPipelines.build_mongodb_history_buckets_backtester(f"{EXAMPLE_TEMPLATES_DIR}/bins.json"), 36 | ) 37 | 38 | save_pipeline_spec( 39 | "backtest_technicals_with_buckets_calculator.json", 40 | LoadersPipelines.build_technicals_with_buckets_calculator( 41 | f"{EXAMPLE_TEMPLATES_DIR}/bins.json", BIN_COUNT, f"{EXAMPLE_TEMPLATES_DIR}/correlation.json" 42 | ), 43 | ) 44 | 45 | save_pipeline_spec("loader_simple_technicals_calculator.json", LoadersPipelines.build_technicals_calculator()) 46 | save_pipeline_spec("loader_simple_returns_calculator.json", LoadersPipelines.build_returns_calculator()) 47 | save_pipeline_spec( 48 | "loader_technicals_with_buckets_matcher.json", 49 | LoadersPipelines.build_technicals_with_buckets_matcher( 50 | f"{EXAMPLE_TEMPLATES_DIR}/bins.json", f"{EXAMPLE_TEMPLATES_DIR}/correlation.json" 51 | ), 52 | ) 53 | 54 | save_pipeline_spec( 55 | "backtest_history_similarity_backtester.json", 56 | BacktestPipelines.build_mongodb_history_similarity_backtester(f"{EXAMPLE_TEMPLATES_DIR}/bins.json"), 57 | ) 58 | 59 | # depends on a running IB gateway 60 | # save_pipeline_spec('loader_simple_daily_loader.json', LoadersPipelines.build_daily_ib_loader()) 61 | 62 | 63 | if __name__ == "__main__": 64 | # setup_default_logger() 65 | 66 | # generate_example_templates() 67 | 68 | initiate_cli() 69 | -------------------------------------------------------------------------------- /src/algotrader/market/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/src/algotrader/market/__init__.py -------------------------------------------------------------------------------- /src/algotrader/market/async_market_provider.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from datetime import datetime 3 | 4 | from algotrader.entities.timespan import TimeSpan 5 | from algotrader.market.async_query_result import AsyncQueryResult 6 | 7 | 8 | class AsyncMarketProvider: 9 | @abstractmethod 10 | def request_symbol_history( 11 | self, symbol: str, candle_timespan: TimeSpan, from_time: datetime, to_time: datetime 12 | ) -> AsyncQueryResult: 13 | pass 14 | -------------------------------------------------------------------------------- /src/algotrader/market/async_query_result.py: -------------------------------------------------------------------------------- 1 | import threading 2 | from datetime import datetime 3 | from typing import List 4 | 5 | from algotrader.entities.candle import Candle 6 | from 
algotrader.providers.ib.query_subscription import QuerySubscription 7 | 8 | 9 | class AsyncQueryResult: 10 | def __init__(self, from_timestamp: datetime, to_timestamp: datetime) -> None: 11 | self.from_timestamp = from_timestamp 12 | self.to_timestamp = to_timestamp 13 | self.done_event = threading.Event() 14 | self.subscriptions: List[QuerySubscription] = [] 15 | self.candles: List[Candle] = [] 16 | 17 | def attach_query_subscription(self, subscription: QuerySubscription): 18 | self.subscriptions.append(subscription) 19 | 20 | def result(self) -> List[Candle]: 21 | results = [sub.result() for sub in self.subscriptions] 22 | candles: List[Candle] = [] 23 | for res in results: 24 | candles += res 25 | 26 | filtered_candles = filter(lambda c: self.from_timestamp <= c.timestamp <= self.to_timestamp, candles) 27 | return sorted(filtered_candles, key=lambda c: c.timestamp) 28 | -------------------------------------------------------------------------------- /src/algotrader/market/ib_market.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | from algotrader.entities.timespan import TimeSpan 4 | from algotrader.market.async_market_provider import AsyncMarketProvider, AsyncQueryResult 5 | from algotrader.providers.ib.interactive_brokers_connector import InteractiveBrokersConnector 6 | 7 | 8 | class IBMarketProvider(AsyncMarketProvider): 9 | def __init__(self, ib_connector: InteractiveBrokersConnector) -> None: 10 | super().__init__() 11 | self.ib_connector = ib_connector 12 | 13 | def request_symbol_history( 14 | self, symbol: str, candle_timespan: TimeSpan, from_time: datetime, to_time: datetime 15 | ) -> AsyncQueryResult: 16 | return self.ib_connector.request_symbol_history(symbol, candle_timespan, from_time, to_time) 17 | -------------------------------------------------------------------------------- /src/algotrader/market/market_provider.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | 4 | class MarketProvider(ABC): 5 | pass 6 | -------------------------------------------------------------------------------- /src/algotrader/market/yahoofinance/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/src/algotrader/market/yahoofinance/__init__.py -------------------------------------------------------------------------------- /src/algotrader/market/yahoofinance/history_provider.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import List 3 | 4 | import yfinance as yf 5 | 6 | from algotrader.entities.candle import Candle 7 | from algotrader.entities.timespan import TimeSpan 8 | 9 | 10 | class YahooFinanceHistoryProvider: 11 | def get_symbol_history( 12 | self, 13 | symbol: str, 14 | period: TimeSpan, 15 | interval: TimeSpan, 16 | start_time: datetime, 17 | end_time: datetime, 18 | auto_adjust: bool = True, 19 | include_after_hours: bool = False, 20 | ) -> List[Candle]: 21 | """ 22 | @param symbol: symbol 23 | @param period: time span of each candle 24 | @param interval: interval between candles 25 | @param start_time: first candle time 26 | @param end_time: latest candle time 27 | @param auto_adjust: auto adjust closing price (dividends, splits) 28 | @param include_after_hours: include pre and post market data 29 | @return: List of candles 
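@raise ValueError: if period or interval is not TimeSpan.Day, TimeSpan.Hour or TimeSpan.Minute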
30 | """ 31 | ticker = yf.Ticker(symbol) 32 | data = ticker.history( 33 | self._translate_timespan(period), 34 | self._translate_timespan(interval), 35 | start_time, 36 | end_time, 37 | include_after_hours, 38 | False, 39 | auto_adjust, 40 | ) 41 | 42 | candles: List[Candle] = [] 43 | for index, row in data.iterrows(): 44 | candle = Candle( 45 | symbol, period, index.to_pydatetime(), row["Open"], row["Close"], row["High"], row["Low"], row["Volume"] 46 | ) 47 | candles.append(candle) 48 | 49 | return candles 50 | 51 | @staticmethod 52 | def _translate_timespan(span: TimeSpan) -> str: 53 | if span == TimeSpan.Day: 54 | return "1d" 55 | elif span == TimeSpan.Hour: 56 | return "1h" 57 | elif span == TimeSpan.Minute: 58 | return "1m" 59 | else: 60 | raise ValueError("minimum timespan for yahoo finance is 1m") 61 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/src/algotrader/pipeline/__init__.py -------------------------------------------------------------------------------- /src/algotrader/pipeline/builders/__init__.py: -------------------------------------------------------------------------------- 1 | from algotrader.calc.calculations import TechnicalCalculation 2 | from algotrader.pipeline.configs.indicator_config import IndicatorConfig 3 | from algotrader.pipeline.configs.technical_processor_config import TechnicalsProcessorConfig 4 | 5 | TECHNICAL_PROCESSOR_CONFIG = TechnicalsProcessorConfig([ 6 | IndicatorConfig("sma5", TechnicalCalculation.SMA, [5]), 7 | IndicatorConfig("sma20", TechnicalCalculation.SMA, [20]), 8 | IndicatorConfig("cci7", TechnicalCalculation.CCI, [7]), 9 | IndicatorConfig("cci14", TechnicalCalculation.CCI, [14]), 10 | IndicatorConfig("macd", TechnicalCalculation.MACD, [2, 5, 9]), 11 | IndicatorConfig("rsi7", TechnicalCalculation.CCI, [7]), 12 | IndicatorConfig("rsi14", TechnicalCalculation.CCI, [14]), 13 | IndicatorConfig("adxr5", TechnicalCalculation.ADXR, [5]), 14 | IndicatorConfig("stddev5", TechnicalCalculation.STDDEV, [5]), 15 | IndicatorConfig("ema5", TechnicalCalculation.EMA, [5]), 16 | IndicatorConfig("ema20", TechnicalCalculation.EMA, [20]), 17 | IndicatorConfig("mom5", TechnicalCalculation.MOM, [5]), 18 | IndicatorConfig("natr5", TechnicalCalculation.NATR, [5]), 19 | IndicatorConfig("meandev5", TechnicalCalculation.MEANDEV, [5]), 20 | IndicatorConfig("obv", TechnicalCalculation.OBV, []), 21 | IndicatorConfig("var5", TechnicalCalculation.VAR, [5]), 22 | IndicatorConfig("vosc", TechnicalCalculation.VOSC, [2, 5]), 23 | ]) 24 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/src/algotrader/pipeline/configs/__init__.py -------------------------------------------------------------------------------- /src/algotrader/pipeline/configs/indicator_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import List, Dict 4 | 5 | from algotrader.calc.calculations import TechnicalCalculation 6 | from algotrader.entities.serializable import Serializable, Deserializable 7 | 8 | 9 | class 
IndicatorConfig(Serializable, Deserializable): 10 | def __init__(self, name: str, calculation: TechnicalCalculation, params: List[any]): 11 | self.name = name 12 | self.type = calculation 13 | self.params = params 14 | 15 | def serialize(self) -> Dict: 16 | return {"name": self.name, "type": self.type.value, "params": self.params} 17 | 18 | @classmethod 19 | def deserialize(cls, data: Dict) -> IndicatorConfig: 20 | return IndicatorConfig(data["name"], TechnicalCalculation(data["type"]), data["params"]) 21 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/configs/technical_processor_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import List, Dict 4 | 5 | from algotrader.entities.serializable import Deserializable, Serializable 6 | from algotrader.pipeline.configs.indicator_config import IndicatorConfig 7 | 8 | 9 | class TechnicalsProcessorConfig(Serializable, Deserializable): 10 | def __init__(self, technicals: List[IndicatorConfig]): 11 | self.technicals = technicals 12 | 13 | def serialize(self) -> Dict: 14 | return {"technicals": [t.serialize() for t in self.technicals]} 15 | 16 | @classmethod 17 | def deserialize(cls, data: Dict) -> TechnicalsProcessorConfig: 18 | return TechnicalsProcessorConfig(technicals=[IndicatorConfig.deserialize(t) for t in data["technicals"]]) 19 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/pipeline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional, Dict 3 | 4 | from algotrader.entities.serializable import Serializable, Deserializable 5 | from algotrader.pipeline.processor import Processor 6 | from algotrader.pipeline.shared_context import SharedContext 7 | from algotrader.pipeline.source import Source 8 | from algotrader.pipeline.terminator import Terminator 9 | from algotrader.serialization.store import DeserializationService 10 | 11 | 12 | class Pipeline(Serializable, Deserializable): 13 | logger = logging.getLogger("Pipeline") 14 | 15 | def __init__(self, source: Source, processor: Processor, terminator: Optional[Terminator] = None) -> None: 16 | self.source = source 17 | self.processor = processor 18 | self.terminator = terminator 19 | 20 | def serialize(self) -> Dict: 21 | obj = super().serialize() 22 | obj.update({ 23 | "source": self.source.serialize(), 24 | "processor": self.processor.serialize(), 25 | "terminator": self.terminator.serialize() if self.terminator else None, 26 | }) 27 | return obj 28 | 29 | @classmethod 30 | def deserialize(cls, data: Dict): 31 | return cls( 32 | DeserializationService.deserialize(data.get("source")), 33 | DeserializationService.deserialize(data.get("processor")), 34 | DeserializationService.deserialize(data.get("terminator")), 35 | ) 36 | 37 | def run(self, context: SharedContext) -> None: 38 | self.logger.info("Starting pipeline...") 39 | 40 | for candle in self.source.read(): 41 | self.logger.debug("Processing candle: %s\r", candle.model_dump()) 42 | self.processor.process(context, candle) 43 | 44 | if self.terminator: 45 | self.logger.debug("initiating termination...") 46 | self.terminator.terminate(context) 47 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/processor.py: -------------------------------------------------------------------------------- 1 
| from __future__ import annotations 2 | 3 | from abc import abstractmethod 4 | from typing import Optional, Dict 5 | 6 | from algotrader.entities.candle import Candle 7 | from algotrader.entities.event import Event 8 | from algotrader.entities.serializable import Deserializable, Serializable 9 | from algotrader.pipeline.shared_context import SharedContext 10 | from algotrader.serialization.store import DeserializationService 11 | 12 | 13 | class Processor(Serializable, Deserializable): 14 | def __init__(self, next_processor: Optional[Processor] = None) -> None: 15 | self.next_processor = next_processor 16 | 17 | @abstractmethod 18 | def process(self, context: SharedContext, candle: Candle): 19 | if self.next_processor: 20 | self.next_processor.process(context, candle) 21 | 22 | def reprocess(self, context: SharedContext, candle: Candle): 23 | if self.next_processor: 24 | self.next_processor.reprocess(context, candle) 25 | 26 | def event(self, context: SharedContext, event: Event): 27 | if self.next_processor: 28 | self.next_processor.event(context, event) 29 | 30 | @classmethod 31 | def deserialize(cls, data: Dict) -> Optional[Processor]: 32 | obj = cls(None) 33 | obj.next_processor = cls._deserialize_next_processor(data) 34 | return obj 35 | 36 | @classmethod 37 | def _deserialize_next_processor(cls, data: Dict) -> Optional[Processor]: 38 | if data.get("next_processor"): 39 | return DeserializationService.deserialize(data["next_processor"]) 40 | return None 41 | 42 | def serialize(self) -> Dict: 43 | obj = super().serialize() 44 | 45 | if self.next_processor: 46 | obj.update({"next_processor": self.next_processor.serialize()}) 47 | 48 | return obj 49 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/processors/__init__.py: -------------------------------------------------------------------------------- 1 | from algotrader.pipeline.processors.assets_correlation import AssetCorrelationProcessor 2 | from algotrader.pipeline.processors.candle_cache import CandleCache 3 | from algotrader.pipeline.processors.file_sink import FileSinkProcessor 4 | from algotrader.pipeline.processors.returns import ReturnsCalculatorProcessor 5 | from algotrader.pipeline.processors.storage_provider_sink import StorageSinkProcessor 6 | from algotrader.pipeline.processors.strategy import StrategyProcessor 7 | from algotrader.pipeline.processors.technicals import TechnicalsProcessor 8 | from algotrader.pipeline.processors.technicals_normalizer import TechnicalsNormalizerProcessor 9 | from algotrader.pipeline.processors.timespan_change import TimeSpanChangeProcessor 10 | 11 | __all__ = [ 12 | "AssetCorrelationProcessor", 13 | "CandleCache", 14 | "FileSinkProcessor", 15 | "ReturnsCalculatorProcessor", 16 | "StorageSinkProcessor", 17 | "StrategyProcessor", 18 | "TechnicalsProcessor", 19 | "TechnicalsNormalizerProcessor", 20 | "TimeSpanChangeProcessor", 21 | ] 22 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/processors/assets_correlation.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Optional, List, Dict 3 | 4 | from scipy import spatial 5 | 6 | from algotrader.entities.attachments.assets_correlation import AssetCorrelation 7 | from algotrader.entities.candle import Candle 8 | from algotrader.entities.event import Event 9 | from algotrader.pipeline.processor import Processor 10 | from 
algotrader.pipeline.processors.candle_cache import CandleCache 11 | from algotrader.pipeline.shared_context import SharedContext 12 | 13 | CORRELATIONS_ATTACHMENT_KEY = "correlations" 14 | CORRELATION_ELEMENTS_COUNT = 4 15 | 16 | 17 | class CorrelationConfig: 18 | def __init__(self, groups: List[List[str]]) -> None: 19 | self.groups: List[List[str]] = groups 20 | 21 | 22 | class AssetCorrelationProcessor(Processor): 23 | """ 24 | Calculates correlations between groups of symbols 25 | """ 26 | 27 | def __init__(self, config_path: str, next_processor: Optional[Processor]) -> None: 28 | """ 29 | @param config_path: path to the correlation's config file 30 | @param next_processor: the next processor in chain 31 | """ 32 | with open(config_path, "r") as config_content: 33 | c: Dict = json.loads(config_content.read()) 34 | self.config = CorrelationConfig(c.get("groups", [])) 35 | 36 | super().__init__(next_processor) 37 | 38 | def process(self, context: SharedContext, candle: Candle): 39 | super().process(context, candle) 40 | 41 | def event(self, context: SharedContext, event: Event): 42 | if event == event.TimeSpanChange: 43 | self._calculate_correlations(context) 44 | 45 | super().event(context, event) 46 | 47 | def _calculate_correlations(self, context: SharedContext): 48 | cache_reader = CandleCache.context_reader(context) 49 | symbols = cache_reader.get_symbols_list() 50 | 51 | for symbol in symbols: 52 | self._calculate_symbol_correlations(context, symbol) 53 | 54 | def _calculate_symbol_correlations(self, context: SharedContext, symbol: str): 55 | cache_reader = CandleCache.context_reader(context) 56 | asset_correlation = AssetCorrelation() 57 | 58 | group_symbols = self._get_symbol_group(symbol) 59 | 60 | if group_symbols: 61 | current_symbol_candles = cache_reader.get_symbol_candles(symbol) or [] 62 | current_symbol_values = self._get_correlation_measurable_values(current_symbol_candles) 63 | 64 | for paired_symbol in group_symbols: 65 | if paired_symbol == symbol: 66 | continue 67 | 68 | symbol_candles = cache_reader.get_symbol_candles(paired_symbol) or [] 69 | symbol_values = self._get_correlation_measurable_values(symbol_candles) 70 | 71 | if ( 72 | len(symbol_values) != len(current_symbol_values) 73 | or len(current_symbol_values) <= CORRELATION_ELEMENTS_COUNT 74 | ): 75 | continue 76 | 77 | correlation = spatial.distance.correlation( 78 | current_symbol_values[-CORRELATION_ELEMENTS_COUNT:], symbol_values[-CORRELATION_ELEMENTS_COUNT:] 79 | ) 80 | asset_correlation.set(paired_symbol, correlation) 81 | 82 | latest_candle = current_symbol_candles[-1] 83 | latest_candle.add_attachment(CORRELATIONS_ATTACHMENT_KEY, asset_correlation) 84 | 85 | self.reprocess(context, latest_candle) 86 | 87 | def _get_symbol_group(self, symbol: str) -> Optional[List[str]]: 88 | for group in self.config.groups: 89 | if symbol in group: 90 | return group 91 | 92 | @staticmethod 93 | def _get_correlation_measurable_values(candles: List[Candle]) -> List[float]: 94 | return [c.close for c in candles] 95 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/processors/candle_cache.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Dict 2 | 3 | from algotrader.entities.candle import Candle 4 | from algotrader.pipeline.processor import Processor 5 | from algotrader.pipeline.shared_context import SharedContext 6 | 7 | CONTEXT_IDENT = "CandleCache" 8 | CacheData = Dict[str, List[Candle]] 9 
| 10 | 11 | class CandleCacheContextWriter: 12 | def __init__(self, context: SharedContext[CacheData]) -> None: 13 | super().__init__() 14 | self.context = context 15 | self.data: CacheData = {} 16 | 17 | def put_candle(self, candle: Candle): 18 | if not self.context.get_kv_data(CONTEXT_IDENT): 19 | self.context.put_kv_data(CONTEXT_IDENT, {}) 20 | 21 | self.data = self.context.get_kv_data(CONTEXT_IDENT) 22 | 23 | if candle.symbol not in self.data: 24 | self.data[candle.symbol] = [] 25 | 26 | self.data[candle.symbol].append(candle) 27 | 28 | 29 | class CandleCacheContextReader: 30 | def __init__(self, context: SharedContext[CacheData]) -> None: 31 | super().__init__() 32 | self.context = context 33 | 34 | def get_symbol_candles(self, symbol: str) -> Optional[List[Candle]]: 35 | data = self.context.get_kv_data(CONTEXT_IDENT) 36 | if data and symbol in data: 37 | return data[symbol] 38 | 39 | def get_symbols_list(self) -> Optional[List[str]]: 40 | data = self.context.get_kv_data(CONTEXT_IDENT) 41 | if data: 42 | return list(data.keys()) 43 | 44 | 45 | class CandleCache(Processor): 46 | """ 47 | Provides a cache facade for processed candles 48 | """ 49 | 50 | def __init__(self, next_processor: Optional[Processor] = None) -> None: 51 | super().__init__(next_processor) 52 | self.data: CacheData = {} 53 | 54 | def reprocess(self, context: SharedContext, candle: Candle): 55 | context_reader = CandleCacheContextReader(context) 56 | candles = context_reader.get_symbol_candles(candle.symbol) 57 | 58 | for i in range(len(candles)): 59 | if candles[i].timestamp == candle.timestamp: 60 | candles[i] = candle 61 | break 62 | 63 | super().reprocess(context, candle) 64 | 65 | def process(self, context: SharedContext, candle: Candle): 66 | context_writer = CandleCacheContextWriter(context) 67 | context_writer.put_candle(candle) 68 | 69 | super().process(context, candle) 70 | 71 | @staticmethod 72 | def context_reader(context: SharedContext) -> CandleCacheContextReader: 73 | return CandleCacheContextReader(context) 74 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/processors/file_sink.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Optional 3 | 4 | from algotrader.entities.candle import Candle 5 | from algotrader.pipeline.processor import Processor 6 | from algotrader.pipeline.shared_context import SharedContext 7 | 8 | 9 | class FileSinkProcessor(Processor): 10 | """ 11 | Write processed candles to file 12 | """ 13 | 14 | def __init__(self, file_path: str, next_processor: Optional[Processor] = None) -> None: 15 | """ 16 | @param file_path: file path to write to 17 | """ 18 | super().__init__(next_processor) 19 | self.file_path = file_path 20 | 21 | def process(self, context: SharedContext, candle: Candle): 22 | with open(self.file_path, "a") as output_file: 23 | line = self._generate_candle_output(context, candle) 24 | output_file.write(f"{line}\n") 25 | 26 | super().process(context, candle) 27 | 28 | def _generate_candle_output(self, context: SharedContext, candle: Candle) -> str: 29 | return json.dumps(candle.model_dump_json()) 30 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/processors/returns.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import List, Dict, Optional 4 | 5 | from 
algotrader.entities.attachments.returns import Returns 6 | from algotrader.entities.candle import Candle 7 | from algotrader.pipeline.processor import Processor 8 | from algotrader.pipeline.processors.candle_cache import CandleCache 9 | from algotrader.pipeline.shared_context import SharedContext 10 | 11 | RETURNS_ATTACHMENT_KEY = "returns" 12 | 13 | 14 | class ReturnsCalculatorProcessor(Processor): 15 | def __init__(self, field_prefix: str, returns_count: int, next_processor: Optional[Processor] = None): 16 | super().__init__(next_processor) 17 | self.field_prefix = field_prefix 18 | self.returns_count = returns_count 19 | 20 | def process(self, context: SharedContext, candle: Candle): 21 | cache_reader = CandleCache.context_reader(context) 22 | symbol_candles = cache_reader.get_symbol_candles(candle.symbol) or [] 23 | 24 | candle.add_attachment(RETURNS_ATTACHMENT_KEY, Returns()) 25 | 26 | if len(symbol_candles) >= self.returns_count: 27 | candle_returns = self._calc_returns(candle, symbol_candles) 28 | candle.add_attachment(RETURNS_ATTACHMENT_KEY, candle_returns) 29 | 30 | if self.next_processor: 31 | self.next_processor.process(context, candle) 32 | 33 | def _calc_returns(self, current_candle: Candle, candles: List[Candle]) -> Returns: 34 | candle_returns = Returns() 35 | for i in range(1, self.returns_count + 1): 36 | candle_returns.set(f"{self.field_prefix}-{i}", (1 - current_candle.close / candles[-i].close) * 100) 37 | 38 | return candle_returns 39 | 40 | def serialize(self) -> Dict: 41 | return {"returnsCount": self.returns_count, "fieldPrefix": self.field_prefix} 42 | 43 | @classmethod 44 | def deserialize(cls, data: Dict) -> Optional[Processor]: 45 | return cls(data["fieldPrefix"], data["returnsCount"], cls._deserialize_next_processor(data)) 46 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/processors/storage_provider_sink.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict 2 | 3 | from algotrader.entities.candle import Candle 4 | from algotrader.pipeline.processor import Processor 5 | from algotrader.pipeline.shared_context import SharedContext 6 | from algotrader.serialization.store import DeserializationService 7 | from algotrader.storage.storage_provider import StorageProvider 8 | 9 | 10 | class StorageSinkProcessor(Processor): 11 | """ 12 | Write all processed candles to a StorageProvider implementation 13 | """ 14 | 15 | def __init__(self, storage_provider: StorageProvider, next_processor: Optional[Processor] = None) -> None: 16 | """ 17 | @param storage_provider: StorageProvider implementation 18 | """ 19 | super().__init__(next_processor) 20 | self.storage_provider = storage_provider 21 | 22 | def process(self, context: SharedContext, candle: Candle): 23 | self.storage_provider.save(candle) 24 | super().process(context, candle) 25 | 26 | def reprocess(self, context: SharedContext, candle: Candle): 27 | self.storage_provider.save(candle) 28 | super().reprocess(context, candle) 29 | 30 | def serialize(self) -> Dict: 31 | obj = super().serialize() 32 | obj.update({ 33 | "storage_provider": self.storage_provider.serialize(), 34 | }) 35 | return obj 36 | 37 | @classmethod 38 | def deserialize(cls, data: Dict): 39 | storage_provider = DeserializationService.deserialize(data["storage_provider"]) 40 | return cls(storage_provider, cls._deserialize_next_processor(data)) 41 | 
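A short illustrative sketch (not part of the repository sources) of how processors are chained through their next_processor argument, mirroring the builders and the JSON pipeline templates; the "ctc" prefix and the output path are arbitrary example values.

```python
# Illustrative sketch of processor chaining (example values only).
from algotrader.pipeline.processors.candle_cache import CandleCache
from algotrader.pipeline.processors.file_sink import FileSinkProcessor
from algotrader.pipeline.processors.returns import ReturnsCalculatorProcessor

# Candles flow: CandleCache -> ReturnsCalculatorProcessor -> FileSinkProcessor.
# Each processor does its work and then hands the candle to next_processor.
chain = CandleCache(
    ReturnsCalculatorProcessor("ctc", 5, FileSinkProcessor("/tmp/candles.jsonl"))
)
```

A Source and an optional Terminator complete a Pipeline around such a chain.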
-------------------------------------------------------------------------------- /src/algotrader/pipeline/processors/strategy.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Dict 2 | 3 | from algotrader.entities.candle import Candle 4 | from algotrader.entities.strategy import Strategy 5 | from algotrader.entities.strategy_signal import StrategySignal 6 | from algotrader.pipeline.processor import Processor 7 | from algotrader.pipeline.shared_context import SharedContext 8 | from algotrader.serialization.store import DeserializationService 9 | from algotrader.trade.signals_executor import SignalsExecutor 10 | 11 | 12 | class StrategyProcessor(Processor): 13 | """ 14 | Main strategy processor. Receives a list of strategies and executes them all on each processed candle. 15 | Forward all strategies signals to a SignalsExecutor implementation 16 | """ 17 | 18 | def __init__( 19 | self, strategies: List[Strategy], signals_executor: SignalsExecutor, next_processor: Optional[Processor] 20 | ) -> None: 21 | """ 22 | @param strategies: List of strategies (Strategy implementations) 23 | @param signals_executor: SignalsExecutor implementation 24 | """ 25 | super().__init__(next_processor) 26 | self.signals_executor = signals_executor 27 | self.strategies = strategies 28 | 29 | def process(self, context: SharedContext, candle: Candle): 30 | signals: List[StrategySignal] = [] 31 | for strategy in self.strategies: 32 | signals += strategy.process(context, candle) or [] 33 | 34 | self.signals_executor.execute(candle, signals) 35 | 36 | super().process(context, candle) 37 | 38 | def serialize(self) -> Dict: 39 | obj = super().serialize() 40 | obj.update({ 41 | "strategies": [strategy.serialize() for strategy in self.strategies], 42 | "signals_executor": self.signals_executor.serialize(), 43 | }) 44 | return obj 45 | 46 | @classmethod 47 | def deserialize(cls, data: Dict) -> Optional[Processor]: 48 | strategies: List[Strategy] = [ 49 | DeserializationService.deserialize(strategy) for strategy in data.get("strategies") 50 | ] 51 | signals_executor: SignalsExecutor = DeserializationService.deserialize(data.get("signals_executor")) 52 | return cls(strategies, signals_executor, cls._deserialize_next_processor(data)) 53 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/processors/technicals.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Optional, List, Dict, Union, Tuple 4 | 5 | from algotrader.calc.technicals import TechnicalCalculator 6 | from algotrader.entities.attachments.technicals import Indicators, IndicatorValue 7 | from algotrader.entities.candle import Candle 8 | from algotrader.pipeline.configs.technical_processor_config import TechnicalsProcessorConfig 9 | from algotrader.pipeline.processor import Processor 10 | from algotrader.pipeline.processors.candle_cache import CandleCache 11 | from algotrader.pipeline.shared_context import SharedContext 12 | 13 | INDICATORS_ATTACHMENT_KEY = "indicators" 14 | TechnicalsData = Dict[str, Dict[str, List[float]]] 15 | 16 | 17 | MAX_CANDLES_FOR_CALC = 50 18 | 19 | 20 | class TechnicalsProcessor(Processor): 21 | """ 22 | Technical indicators processor. Using Tulip indicators to calculate and attach technicals to the processed candles. 23 | Make use of the SharedContext to keep track of earlier candles. 
24 | """ 25 | 26 | def __init__(self, config: TechnicalsProcessorConfig, next_processor: Optional[Processor]) -> None: 27 | super().__init__(next_processor) 28 | self.config = config 29 | 30 | def process(self, context: SharedContext, candle: Candle): 31 | cache_reader = CandleCache.context_reader(context) 32 | symbol_candles = cache_reader.get_symbol_candles(candle.symbol) or [] 33 | calculator = TechnicalCalculator(symbol_candles[-MAX_CANDLES_FOR_CALC:] + [candle]) 34 | 35 | candle_indicators = Indicators() 36 | self._calculate(calculator, candle_indicators) 37 | candle.add_attachment(INDICATORS_ATTACHMENT_KEY, candle_indicators) 38 | 39 | super().process(context, candle) 40 | 41 | def _calculate(self, calculator: TechnicalCalculator, candle_indicators: Indicators): 42 | for technicalConfig in self.config.technicals: 43 | results = calculator.execute(technicalConfig.type, technicalConfig.params) 44 | candle_indicators.set(technicalConfig.name, TechnicalsProcessor._get_last_value(results)) 45 | 46 | @staticmethod 47 | def _get_last_value(values: Union[Tuple[List[float]], List[float]]) -> Optional[IndicatorValue]: 48 | if isinstance(values, tuple): 49 | return [v[-1] for v in values] 50 | elif isinstance(values, list) and values: 51 | return values[-1] 52 | 53 | def serialize(self) -> Dict: 54 | obj = super().serialize() 55 | obj.update({"config": self.config.serialize()}) 56 | return obj 57 | 58 | @classmethod 59 | def deserialize(cls, data: Dict) -> Optional[Processor]: 60 | config = TechnicalsProcessorConfig.deserialize(data["config"]) 61 | return cls(config, cls._deserialize_next_processor(data)) 62 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/processors/technicals_buckets_matcher.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Dict 2 | 3 | from algotrader.entities.attachments.technicals_buckets_matcher import IndicatorsMatchedBuckets 4 | from algotrader.entities.bucket import Bucket, BucketList 5 | from algotrader.entities.bucketscontainer import BucketsContainer 6 | from algotrader.entities.candle import Candle 7 | from algotrader.pipeline.processor import Processor 8 | from algotrader.pipeline.processors.technicals_normalizer import ( 9 | NORMALIZED_INDICATORS_ATTACHMENT_KEY, 10 | ) 11 | from algotrader.entities.attachments.technicals_normalizer import NormalizedIndicators 12 | from algotrader.pipeline.shared_context import SharedContext 13 | 14 | INDICATORS_MATCHED_BUCKETS_ATTACHMENT_KEY = "indicators_matched_buckets" 15 | 16 | 17 | class TechnicalsBucketsMatcher(Processor): 18 | """ 19 | Match technical indicators to buckets and saves them on the candle attachments objects. 20 | This processor is a companion for the TechnicalsBinner terminator which in charge of creating the bins and 21 | saving them to file. 22 | Use the TechnicalsBinner on historical data to create the bins, then run realtime date with 23 | this processor and get the matching bins. 
24 | """ 25 | 26 | def __init__(self, bins_file_path: str, next_processor: Optional[Processor]) -> None: 27 | """ 28 | @param bins_file_path: path to the bins file created by TechnicalsBinner 29 | """ 30 | super().__init__(next_processor) 31 | 32 | self.bins_file_path = bins_file_path 33 | self.bins: Optional[BucketsContainer] = None 34 | 35 | def _lazy_load_bins_file(self): 36 | if not self.bins: 37 | with open(self.bins_file_path) as bins_file_content: 38 | content = bins_file_content.read() 39 | self.bins: BucketsContainer = BucketsContainer.model_validate_json(content) 40 | 41 | def process(self, context: SharedContext, candle: Candle): 42 | normalized_indicators: NormalizedIndicators = candle.get_attachment(NORMALIZED_INDICATORS_ATTACHMENT_KEY) 43 | 44 | self._lazy_load_bins_file() 45 | matched_buckets = IndicatorsMatchedBuckets() 46 | 47 | for indicator, value in normalized_indicators.items(): 48 | bins = self.bins.get(indicator) 49 | if bins: 50 | if isinstance(bins[0], list): 51 | match = self._indicator_list_match(value, bins) 52 | else: 53 | match = self._indicator_match(value, bins) 54 | 55 | matched_buckets.set(indicator, match) 56 | 57 | candle.add_attachment(INDICATORS_MATCHED_BUCKETS_ATTACHMENT_KEY, matched_buckets) 58 | 59 | super().process(context, candle) 60 | 61 | def _indicator_list_match(self, values: List[float], bins: List[BucketList]) -> List[Optional[Bucket]]: 62 | return [self._indicator_match(values[i], bins[i]) for i in range(len(values))] 63 | 64 | def _indicator_match(self, value: float, bins: BucketList) -> Optional[Bucket]: 65 | for bin in bins: 66 | if bin.get_start <= value < bin.get_end: 67 | return bin 68 | 69 | def serialize(self) -> Dict: 70 | obj = super().serialize() 71 | obj.update({"bins_file_path": self.bins_file_path}) 72 | return obj 73 | 74 | @classmethod 75 | def deserialize(cls, data: Dict) -> Optional[Processor]: 76 | return cls(data.get("bins_file_path"), cls._deserialize_next_processor(data)) 77 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/processors/timespan_change.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Optional 3 | 4 | from algotrader.entities.candle import Candle 5 | from algotrader.entities.event import Event 6 | from algotrader.entities.timespan import TimeSpan 7 | from algotrader.pipeline.processor import Processor 8 | from algotrader.pipeline.shared_context import SharedContext 9 | 10 | 11 | class TimeSpanChangeProcessor(Processor): 12 | """ 13 | Event emitter. 14 | Keeps track of processed candles timestamps and emits a Event.TimeSpanChange upon a TimeSpan change. 
15 | """ 16 | 17 | def __init__(self, timespan: TimeSpan, next_processor: Optional[Processor]) -> None: 18 | """ 19 | @param timespan: What TimeSpan we are tracking 20 | """ 21 | super().__init__(next_processor) 22 | self.timespan = timespan 23 | self.latest_candle: Optional[Candle] = None 24 | 25 | def process(self, context: SharedContext, candle: Candle): 26 | if ( 27 | self.latest_candle 28 | and candle.time_span == self.timespan 29 | and self._is_diff(candle.timestamp, self.latest_candle.timestamp) 30 | ): 31 | self.next_processor.event(context, Event.TimeSpanChange) 32 | 33 | self.latest_candle = candle 34 | 35 | super().process(context, candle) 36 | 37 | def _is_diff(self, one: datetime, other: datetime) -> bool: 38 | if self.timespan == TimeSpan.Day: 39 | return one.date() != other.date() 40 | else: 41 | return True 42 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/reverse_source.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Dict 2 | 3 | from algotrader.entities.candle import Candle 4 | from algotrader.pipeline.source import Source 5 | 6 | 7 | class ReverseSource(Source): 8 | def __init__(self, source: Source) -> None: 9 | super().__init__() 10 | self.source = source 11 | 12 | def read(self) -> Iterator[Candle]: 13 | candles = list(self.source.read()) 14 | candles.reverse() 15 | 16 | for c in candles: 17 | yield c 18 | 19 | def serialize(self) -> Dict: 20 | obj = super().serialize() 21 | obj.update({"source": self.source.serialize()}) 22 | 23 | return obj 24 | 25 | @classmethod 26 | def deserialize(cls, data: Dict): 27 | source: Source = Source.deserialize(data["source"]) 28 | return cls(source) 29 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/runner.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Optional, Union 3 | 4 | from algotrader.pipeline.pipeline import Pipeline 5 | from algotrader.pipeline.shared_context import SharedContext 6 | 7 | 8 | class PipelineRunner: 9 | logger = logging.getLogger("PipelineRunner") 10 | 11 | def __init__(self, pipelines: Union[Pipeline, List[Pipeline]], context: Optional[SharedContext] = None) -> None: 12 | self.pipelines: List[Pipeline] = pipelines if isinstance(pipelines, list) else [pipelines] 13 | self.context = context or SharedContext() 14 | 15 | def run(self): 16 | self.logger.info("Starting pipeline runner...") 17 | for pipeline in self.pipelines: 18 | pipeline.run(self.context) 19 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/shared_context.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, TypeVar, Generic 2 | 3 | T = TypeVar("T") 4 | 5 | 6 | class SharedContext(Generic[T]): 7 | def __init__(self) -> None: 8 | self._kv_store: Dict[str, object] = {} 9 | 10 | def put_kv_data(self, key: str, value: T): 11 | self._kv_store[key] = value 12 | 13 | def get_kv_data(self, key: str, default: object = None) -> Optional[T]: 14 | if key in self._kv_store: 15 | return self._kv_store[key] 16 | 17 | return default 18 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/source.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 
2 | from typing import Iterator 3 | 4 | from algotrader.entities.candle import Candle 5 | from algotrader.entities.serializable import Deserializable, Serializable 6 | 7 | 8 | class Source(Serializable, Deserializable): 9 | @abstractmethod 10 | def read(self) -> Iterator[Candle]: 11 | pass 12 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/sources/__init__.py: -------------------------------------------------------------------------------- 1 | # import algotrader.pipeline.sources.ib_history 2 | from algotrader.pipeline.sources.binance_history import BinanceHistorySource 3 | from algotrader.pipeline.sources.binance_realtime import BinanceRealtimeSource 4 | from algotrader.pipeline.sources.mongodb_source import MongoDBSource 5 | from algotrader.pipeline.sources.yahoo_finance_history import YahooFinanceHistorySource 6 | 7 | __all__ = ["BinanceHistorySource", "BinanceRealtimeSource", "MongoDBSource", "YahooFinanceHistorySource"] 8 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/sources/binance_history.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Iterator, List, Dict 3 | 4 | from algotrader.entities.candle import Candle 5 | from algotrader.entities.timespan import TimeSpan 6 | from algotrader.pipeline.source import Source 7 | from algotrader.providers.binance import BinanceProvider 8 | 9 | 10 | class BinanceHistorySource(Source): 11 | def __init__( 12 | self, 13 | binance_provider: BinanceProvider, 14 | symbols: List[str], 15 | time_span: TimeSpan, 16 | start_time: datetime, 17 | end_time: datetime = datetime.now(), 18 | ): 19 | self.binance_provider = binance_provider 20 | self.symbols = symbols 21 | self.time_span = time_span 22 | self.start_time = start_time 23 | self.end_time = end_time 24 | 25 | def read(self) -> Iterator[Candle]: 26 | for symbol in self.symbols: 27 | candles = self.binance_provider.get_symbol_history(symbol, self.time_span, self.start_time, self.end_time) 28 | for candle in candles: 29 | yield candle 30 | 31 | def serialize(self) -> Dict: 32 | obj = super().serialize() 33 | obj.update({ 34 | "binanceProvider": self.binance_provider.serialize(), 35 | "symbols": self.symbols, 36 | "timeSpan": self.time_span.value, 37 | "startTime": self.start_time.timestamp(), 38 | "endTime": self.end_time.timestamp(), 39 | }) 40 | return obj 41 | 42 | @classmethod 43 | def deserialize(cls, data: Dict): 44 | provider = BinanceProvider.deserialize(data.get("binanceProvider")) 45 | return cls( 46 | provider, 47 | data.get("symbols"), 48 | TimeSpan(data.get("timeSpan")), 49 | datetime.fromtimestamp(data.get("startTime")), 50 | datetime.fromtimestamp(data.get("endTime")), 51 | ) 52 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/sources/binance_realtime.py: -------------------------------------------------------------------------------- 1 | from queue import Queue 2 | from typing import List, Dict, Iterator 3 | 4 | from algotrader.entities.candle import Candle 5 | from algotrader.entities.timespan import TimeSpan 6 | from algotrader.pipeline.source import Source 7 | from algotrader.providers.binance import BinanceProvider 8 | 9 | 10 | class BinanceRealtimeSource(Source): 11 | def __init__(self, binance_provider: BinanceProvider, symbols: List[str], time_span: TimeSpan): 12 | self.binance_provider = binance_provider 13 | 
self.symbols = symbols 14 | self.time_span = time_span 15 | self.queue = Queue() 16 | 17 | self._last_received_candle: Dict[str, Candle] = {} 18 | 19 | def read(self) -> Iterator[Candle]: 20 | for symbol in self.symbols: 21 | self.binance_provider.start_kline_socket(symbol, self.time_span, self._on_candle) 22 | 23 | while self.binance_provider.is_socket_alive(): 24 | yield self.queue.get() 25 | 26 | def _on_candle(self, candle: Candle): 27 | if ( 28 | candle.symbol in self._last_received_candle 29 | and candle.timestamp > self._last_received_candle[candle.symbol].timestamp 30 | ): 31 | self.queue.put(self._last_received_candle[candle.symbol]) 32 | 33 | self._last_received_candle[candle.symbol] = candle 34 | 35 | def serialize(self) -> Dict: 36 | obj = super().serialize() 37 | obj.update({ 38 | "binanceProvider": self.binance_provider.serialize(), 39 | "symbols": self.symbols, 40 | "timeSpan": self.time_span.value, 41 | }) 42 | return obj 43 | 44 | @classmethod 45 | def deserialize(cls, data: Dict): 46 | provider = BinanceProvider.deserialize(data.get("binanceProvider")) 47 | return cls(provider, data.get("symbols"), TimeSpan(data.get("timeSpan"))) 48 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/sources/ib_history.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from datetime import datetime 3 | from typing import Iterator, List, Optional 4 | 5 | from algotrader.entities.candle import Candle 6 | from algotrader.entities.timespan import TimeSpan 7 | from algotrader.market.ib_market import IBMarketProvider 8 | from algotrader.pipeline.source import Source 9 | from algotrader.providers.ib.interactive_brokers_connector import InteractiveBrokersConnector 10 | 11 | 12 | class IBHistorySource(Source): 13 | """ 14 | Source for fetching data from Interactive Brokers 15 | """ 16 | 17 | def __init__( 18 | self, 19 | ib_connector: InteractiveBrokersConnector, 20 | symbols: List[str], 21 | timespan: TimeSpan, 22 | from_time: datetime, 23 | to_time: Optional[datetime] = datetime.now(), 24 | ) -> None: 25 | """ 26 | @param ib_connector: InteractiveBrokersConnector instance 27 | @param symbols: symbols to fetch 28 | @param timespan: timespan of candles 29 | @param from_time: time to start fetching from 30 | @param to_time: time to fetch to 31 | """ 32 | self.timespan = timespan 33 | self.to_time = to_time 34 | self.from_time = from_time 35 | self.marketProvider = IBMarketProvider(ib_connector) 36 | self.symbols = symbols 37 | 38 | def read(self) -> Iterator[Candle]: 39 | for symbol in self.symbols: 40 | try: 41 | result = self.marketProvider.request_symbol_history(symbol, self.timespan, self.from_time, self.to_time) 42 | for candle in result.result(): 43 | yield candle 44 | except Exception as ex: 45 | logging.warning(f"Failed to fetch symbol {symbol} history. 
Error: {ex}") 46 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/sources/mongodb_source.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from datetime import datetime 3 | from typing import Iterator, List, Optional, Dict 4 | 5 | from algotrader.entities.candle import Candle 6 | from algotrader.entities.timespan import TimeSpan 7 | from algotrader.pipeline.source import Source 8 | from algotrader.storage.mongodb_storage import MongoDBStorage 9 | 10 | 11 | class MongoDBSource(Source): 12 | """ 13 | Source for fetching data from MongoDB 14 | """ 15 | 16 | logger = logging.getLogger("MongoDBSource") 17 | 18 | def __init__( 19 | self, 20 | mongo_storage: MongoDBStorage, 21 | symbols: List[str], 22 | timespan: TimeSpan, 23 | from_time: datetime, 24 | to_time: Optional[datetime] = datetime.now(), 25 | ) -> None: 26 | """ 27 | @param mongo_storage: MongoDBStorage instance 28 | @param symbols: list of symbols to fetch 29 | @param timespan: timespan of candles 30 | @param from_time: time to start fetching from 31 | @param to_time: time to fetch to 32 | """ 33 | self.timespan = timespan 34 | self.to_time = to_time 35 | self.from_time = from_time 36 | self.mongo_storage = mongo_storage 37 | self.symbols = symbols 38 | 39 | def read(self) -> Iterator[Candle]: 40 | self.logger.info("Fetching candles from mongo source...") 41 | all_candles = self.mongo_storage.get_candles(self.timespan, self.from_time, self.to_time) 42 | self.logger.info("Got candles, starting iteration") 43 | for c in all_candles: 44 | if c.symbol in self.symbols: 45 | yield c 46 | 47 | def serialize(self) -> Dict: 48 | obj = super().serialize() 49 | obj.update({ 50 | "mongo_storage": self.mongo_storage.serialize(), 51 | "symbols": self.symbols, 52 | "timespan": self.timespan.value, 53 | "from_time": self.from_time, 54 | "to_time": self.to_time, 55 | }) 56 | return obj 57 | 58 | @classmethod 59 | def deserialize(cls, data: Dict): 60 | storage = MongoDBStorage.deserialize(data.get("mongo_storage")) 61 | return cls( 62 | storage, data.get("symbols"), TimeSpan(data.get("timespan")), data.get("from_time"), data.get("to_time") 63 | ) 64 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/sources/yahoo_finance_history.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from datetime import datetime 3 | from typing import Iterator, List, Dict 4 | 5 | from algotrader.entities.candle import Candle 6 | from algotrader.entities.timespan import TimeSpan 7 | from algotrader.market.yahoofinance.history_provider import YahooFinanceHistoryProvider 8 | from algotrader.pipeline.source import Source 9 | 10 | 11 | class YahooFinanceHistorySource(Source): 12 | """ 13 | Source for fetching historical data from Yahoo Finance 14 | """ 15 | 16 | def __init__( 17 | self, 18 | symbols: List[str], 19 | timespan: TimeSpan, 20 | start_time: datetime, 21 | end_time: datetime, 22 | auto_adjust: bool = True, 23 | include_after_hours: bool = False, 24 | sort_all: bool = False, 25 | ): 26 | """ 27 | @param symbols: list of symbols to fetch 28 | @param timespan: candles timespan 29 | @param start_time: first candle time 30 | @param end_time: latest candle time 31 | @param auto_adjust: auto adjust closing price (dividends, splits) 32 | @param include_after_hours: include pre and post market data 33 | @param sort_all: sort candles by time cross 
symbols (will start streaming candles only after all 34 | symbols fetched (slower for first response) 35 | """ 36 | self.sort_all = sort_all 37 | self.symbols = symbols 38 | self.auto_adjust = auto_adjust 39 | self.include_after_hours = include_after_hours 40 | self.end_time = end_time 41 | self.start_time = start_time 42 | self.timespan = timespan 43 | self.provider = YahooFinanceHistoryProvider() 44 | 45 | def fetch_symbol(self, symbol: str) -> Iterator[Candle]: 46 | logging.info(f"Fetching {symbol} history from Yahoo Finance") 47 | candles = self.provider.get_symbol_history( 48 | symbol, 49 | self.timespan, 50 | self.timespan, 51 | self.start_time, 52 | self.end_time, 53 | self.auto_adjust, 54 | self.include_after_hours, 55 | ) 56 | for candle in candles: 57 | yield candle 58 | 59 | def _read_quick(self) -> Iterator[Candle]: 60 | for symbol in self.symbols: 61 | for candle in self.fetch_symbol(symbol): 62 | yield candle 63 | 64 | def _read_sort(self) -> List[Candle]: 65 | candles: List[Candle] = [] 66 | for symbol in self.symbols: 67 | candles.extend(self.fetch_symbol(symbol)) 68 | 69 | return sorted(candles, key=lambda candle: candle.timestamp) 70 | 71 | def read(self) -> Iterator[Candle]: 72 | if self.sort_all: 73 | for candle in self._read_sort(): 74 | yield candle 75 | else: 76 | for candle in self._read_quick(): 77 | yield candle 78 | 79 | def serialize(self) -> Dict: 80 | obj = super().serialize() 81 | obj.update({ 82 | "symbols": self.symbols, 83 | "timespan": self.timespan.value, 84 | "start_time": self.start_time.timestamp(), 85 | "end_time": self.end_time.timestamp(), 86 | "auto_adjust": self.auto_adjust, 87 | "include_after_hours": self.include_after_hours, 88 | "sort_all": self.sort_all, 89 | }) 90 | return obj 91 | 92 | @classmethod 93 | def deserialize(cls, data: Dict): 94 | return cls( 95 | data.get("symbols"), 96 | TimeSpan(data.get("timespan")), 97 | datetime.fromtimestamp(data.get("start_time")), 98 | datetime.fromtimestamp(data.get("end_time")), 99 | data.get("auto_adjust"), 100 | data.get("include_after_hours"), 101 | data.get("sort_all"), 102 | ) 103 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/strategies/__init__.py: -------------------------------------------------------------------------------- 1 | from algotrader.pipeline.strategies.connors_rsi2 import ConnorsRSI2 2 | from algotrader.pipeline.strategies.history_bucket_compare import HistoryBucketCompareStrategy 3 | from algotrader.pipeline.strategies.history_cosine_similarity import HistoryCosineSimilarityStrategy 4 | from algotrader.pipeline.strategies.simple_sma import SimpleSMA 5 | 6 | __all__ = ["SimpleSMA", "ConnorsRSI2", "HistoryBucketCompareStrategy", "HistoryCosineSimilarityStrategy"] 7 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/strategies/connors_rsi2.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Dict 2 | 3 | from algotrader.entities.candle import Candle 4 | from algotrader.entities.strategy import Strategy 5 | from algotrader.entities.strategy_signal import StrategySignal, SignalDirection 6 | from algotrader.pipeline.processors.candle_cache import CandleCache 7 | from algotrader.pipeline.processors.technicals import INDICATORS_ATTACHMENT_KEY 8 | from algotrader.entities.attachments.technicals import Indicators 9 | from algotrader.pipeline.shared_context import SharedContext 10 | 11 | 12 | class 
ConnorsRSI2(Strategy): 13 | def __init__(self) -> None: 14 | super().__init__() 15 | self.current_position: Dict[str, Optional[SignalDirection]] = {} 16 | 17 | def process(self, context: SharedContext, candle: Candle) -> List[StrategySignal]: 18 | cache_reader = CandleCache.context_reader(context) 19 | symbol_candles = cache_reader.get_symbol_candles(candle.symbol) 20 | 21 | if not symbol_candles or len(symbol_candles) < 1: 22 | return [] 23 | 24 | if candle.symbol not in self.current_position: 25 | self.current_position[candle.symbol] = None 26 | 27 | past_candle_indicators: Indicators = symbol_candles[-1].get_attachment(INDICATORS_ATTACHMENT_KEY) 28 | current_candle_indicators: Indicators = candle.get_attachment(INDICATORS_ATTACHMENT_KEY) 29 | 30 | if ( 31 | not current_candle_indicators.has("rsi2") 32 | or not current_candle_indicators.has("sma50") 33 | or not past_candle_indicators.has("rsi2") 34 | ): 35 | return [] 36 | 37 | if self.current_position[candle.symbol] == SignalDirection.Long: 38 | if candle.close > current_candle_indicators["sma5"]: 39 | self.current_position[candle.symbol] = None 40 | return [StrategySignal(candle.symbol, SignalDirection.Short)] 41 | 42 | return [] 43 | 44 | if self.current_position[candle.symbol] == SignalDirection.Short: 45 | if candle.close < current_candle_indicators["sma5"]: 46 | self.current_position[candle.symbol] = None 47 | return [StrategySignal(candle.symbol, SignalDirection.Long)] 48 | 49 | return [] 50 | 51 | if ( 52 | candle.close > current_candle_indicators["sma50"] 53 | and current_candle_indicators["rsi2"] < 10 < past_candle_indicators["rsi2"] 54 | ): 55 | self.current_position[candle.symbol] = SignalDirection.Long 56 | return [StrategySignal(candle.symbol, SignalDirection.Long)] 57 | 58 | if ( 59 | candle.close < current_candle_indicators["sma50"] 60 | and current_candle_indicators["rsi2"] > 90 > past_candle_indicators["rsi2"] 61 | ): 62 | self.current_position[candle.symbol] = SignalDirection.Short 63 | return [StrategySignal(candle.symbol, SignalDirection.Short)] 64 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/strategies/history_cosine_similarity.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import List, Dict 3 | 4 | from scipy import spatial 5 | 6 | from algotrader.entities.candle import Candle 7 | from algotrader.entities.strategy import Strategy 8 | from algotrader.entities.strategy_signal import StrategySignal, SignalDirection 9 | from algotrader.pipeline.processors.technicals_buckets_matcher import ( 10 | INDICATORS_MATCHED_BUCKETS_ATTACHMENT_KEY, 11 | ) 12 | from algotrader.entities.attachments.technicals_buckets_matcher import IndicatorsMatchedBuckets 13 | from algotrader.pipeline.shared_context import SharedContext 14 | from algotrader.serialization.store import DeserializationService 15 | from algotrader.storage.storage_provider import StorageProvider 16 | 17 | 18 | class HistoryCosineSimilarityStrategy(Strategy): 19 | def __init__( 20 | self, 21 | storage_provider: StorageProvider, 22 | timeframe_start: datetime, 23 | timeframe_end: datetime, 24 | indicators_to_compare: List[str], 25 | return_field: str, 26 | min_event_count: int, 27 | min_avg_return: float, 28 | ) -> None: 29 | self.timeframe_start = timeframe_start 30 | self.timeframe_end = timeframe_end 31 | self.storage_provider = storage_provider 32 | self.indicators_to_compare = indicators_to_compare 33 | 
self.return_field = return_field 34 | self.min_event_count = min_event_count 35 | self.min_avg_return = min_avg_return 36 | 37 | groupby_fields = [f"attachments.indicators_matched_buckets.{ind}.ident" for ind in self.indicators_to_compare] 38 | return_fields = [f"attachments.returns.ctc-{i}" for i in range(1, 20)] 39 | 40 | self.long_matchers, self.short_matchers = storage_provider.get_aggregated_history( 41 | timeframe_start, timeframe_end, groupby_fields, return_fields, min_event_count, min_avg_return 42 | ) 43 | 44 | def process(self, context: SharedContext, candle: Candle) -> List[StrategySignal]: 45 | indicators_buckets: IndicatorsMatchedBuckets = candle.get_attachment(INDICATORS_MATCHED_BUCKETS_ATTACHMENT_KEY) 46 | 47 | candle_values: list[int] = [] 48 | for indicator in self.indicators_to_compare: 49 | if not indicators_buckets.has(indicator): 50 | return [] 51 | 52 | candle_values.append(indicators_buckets.get(indicator).ident) 53 | 54 | for matcher in self.long_matchers: 55 | matcher_values: list[int] = [] 56 | for indicator in self.indicators_to_compare: 57 | matcher_values.append(matcher[f"attachments.indicators_matched_buckets.{indicator}.ident"]) 58 | 59 | result = 1 - spatial.distance.cosine(candle_values, matcher_values) 60 | if result > 0.997: 61 | return [StrategySignal(candle.symbol, SignalDirection.Long)] 62 | 63 | return [] 64 | 65 | def serialize(self) -> Dict: 66 | obj = super().serialize() 67 | obj.update({ 68 | "storage_provider": self.storage_provider.serialize(), 69 | "timeframe_start": self.timeframe_start, 70 | "timeframe_end": self.timeframe_end, 71 | "indicators_to_compare": self.indicators_to_compare, 72 | "return_field": self.return_field, 73 | "min_event_count": self.min_event_count, 74 | "min_avg_return": self.min_avg_return, 75 | }) 76 | return obj 77 | 78 | @classmethod 79 | def deserialize(cls, data: Dict): 80 | storage_provider: StorageProvider = DeserializationService.deserialize(data.get("storage_provider")) 81 | 82 | return cls( 83 | storage_provider, 84 | data.get("timeframe_start", datetime.now()), 85 | data.get("timeframe_end", datetime.now()), 86 | data.get("indicators_to_compare", []), 87 | data.get("return_field", ""), 88 | data.get("min_event_count", 0), 89 | data.get("min_avg_return", 0), 90 | ) 91 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/strategies/simple_sma.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from algotrader.entities.candle import Candle 4 | from algotrader.entities.strategy import Strategy 5 | from algotrader.entities.strategy_signal import StrategySignal, SignalDirection 6 | from algotrader.pipeline.processors.candle_cache import CandleCache 7 | from algotrader.pipeline.processors.technicals import INDICATORS_ATTACHMENT_KEY 8 | from algotrader.entities.attachments.technicals import Indicators 9 | from algotrader.pipeline.shared_context import SharedContext 10 | 11 | 12 | class SimpleSMA(Strategy): 13 | """ 14 | Simple Moving average strategy 15 | """ 16 | 17 | def process(self, context: SharedContext, candle: Candle) -> List[StrategySignal]: 18 | cache_reader = CandleCache.context_reader(context) 19 | symbol_candles = cache_reader.get_symbol_candles(candle.symbol) 20 | 21 | if not symbol_candles or len(symbol_candles) < 1: 22 | return [] 23 | 24 | past_candle_indicators: Indicators = symbol_candles[-1].get_attachment(INDICATORS_ATTACHMENT_KEY) 25 | current_candle_indicators: Indicators = 
candle.get_attachment(INDICATORS_ATTACHMENT_KEY) 26 | 27 | if not current_candle_indicators.has("sma20") or not past_candle_indicators.has("sma20"): 28 | return [] 29 | 30 | if ( 31 | current_candle_indicators["sma5"] > current_candle_indicators["sma20"] 32 | and past_candle_indicators["sma5"] < past_candle_indicators["sma20"] 33 | ): 34 | return [StrategySignal(candle.symbol, SignalDirection.Long)] 35 | 36 | if ( 37 | current_candle_indicators["sma5"] < current_candle_indicators["sma20"] 38 | and past_candle_indicators["sma5"] > past_candle_indicators["sma20"] 39 | ): 40 | return [StrategySignal(candle.symbol, SignalDirection.Short)] 41 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/terminator.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | from algotrader.entities.serializable import Deserializable, Serializable 4 | from algotrader.pipeline.shared_context import SharedContext 5 | 6 | 7 | class Terminator(Serializable, Deserializable): 8 | @abstractmethod 9 | def terminate(self, context: SharedContext): 10 | pass 11 | -------------------------------------------------------------------------------- /src/algotrader/pipeline/terminators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/src/algotrader/pipeline/terminators/__init__.py -------------------------------------------------------------------------------- /src/algotrader/pipeline/terminators/technicals_binner.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | from typing import List, Dict, Tuple 5 | 6 | from algotrader.entities.attachments.technicals import IndicatorValue 7 | from algotrader.entities.bucket import Bucket 8 | from algotrader.entities.bucketscontainer import BucketsContainer 9 | from algotrader.entities.candle import Candle 10 | from algotrader.pipeline.processors.candle_cache import CandleCache 11 | from algotrader.pipeline.processors.technicals_normalizer import ( 12 | NORMALIZED_INDICATORS_ATTACHMENT_KEY, 13 | ) 14 | from algotrader.entities.attachments.technicals_normalizer import NormalizedIndicators 15 | from algotrader.pipeline.shared_context import SharedContext 16 | from algotrader.pipeline.terminator import Terminator 17 | 18 | 19 | class TechnicalsBinner(Terminator): 20 | def __init__( 21 | self, symbols: List[str], bins_count: int, output_file_path: str, outliers_removal_percentage: float = 0.05 22 | ) -> None: 23 | super().__init__() 24 | self.outliers_removal_percentage = outliers_removal_percentage 25 | self.symbols = symbols 26 | self.output_file_path = output_file_path 27 | self.values: Dict[str, List[IndicatorValue]] = {} 28 | self.bins = BucketsContainer() 29 | self.bins_count = bins_count 30 | 31 | def terminate(self, context: SharedContext): 32 | cache_reader = CandleCache.context_reader(context) 33 | 34 | for symbol in self.symbols: 35 | symbol_candles = cache_reader.get_symbol_candles(symbol) or [] 36 | for candle in symbol_candles: 37 | self._process_candle(candle) 38 | 39 | self._calculate_bins() 40 | self._save_bins() 41 | 42 | def _process_candle(self, candle: Candle): 43 | normalized_indicators: NormalizedIndicators = candle.get_attachment(NORMALIZED_INDICATORS_ATTACHMENT_KEY) 44 | 45 | if not normalized_indicators: 46 | 
return 47 | 48 | for indicator, value in normalized_indicators.items(): 49 | if indicator not in self.values: 50 | self.values[indicator] = [] 51 | 52 | self.values[indicator].append(value) 53 | 54 | def _calculate_bins(self): 55 | for indicator, values in self.values.items(): 56 | if isinstance(values[0], float): 57 | self.bins.add(indicator, self._get_single_float_bins(values)) 58 | elif isinstance(values[0], list) or isinstance(values[0], Tuple): 59 | list_size = len(values[0]) 60 | bins: List[List[Bucket]] = [] 61 | for i in range(list_size): 62 | bins.append(self._get_single_float_bins([v[i] for v in values])) 63 | 64 | self.bins.add(indicator, bins) 65 | 66 | def _get_single_float_bins(self, values: List[float]) -> List[Bucket]: 67 | values.sort() 68 | 69 | margins = int(len(values) * self.outliers_removal_percentage) 70 | values = values[margins : len(values) - margins] 71 | 72 | step_size = int(math.floor(len(values) / self.bins_count)) 73 | 74 | bins: List[Bucket] = [Bucket(ident=0, end=values[0])] 75 | 76 | for i in range(0, len(values), step_size): 77 | bins.append(Bucket(ident=len(bins), start=values[i], end=values[min(i + step_size, len(values) - 1)])) 78 | 79 | bins.append(Bucket(ident=len(bins), start=values[len(values) - 1])) 80 | 81 | return bins 82 | 83 | def _save_bins(self): 84 | with open(self.output_file_path, "w+") as output_file: 85 | output_file.write(self.bins.model_dump_json()) 86 | 87 | def serialize(self) -> Dict: 88 | obj = super().serialize() 89 | obj.update({ 90 | "symbols": self.symbols, 91 | "bins_count": self.bins_count, 92 | "output_file_path": self.output_file_path, 93 | "outliers_removal_percentage": self.outliers_removal_percentage, 94 | }) 95 | return obj 96 | 97 | @classmethod 98 | def deserialize(cls, data: Dict): 99 | return cls( 100 | data.get("symbols", []), 101 | data.get("bins_count", 0), 102 | data.get("output_file_path", ""), 103 | data.get("outliers_removal_percentage", 0.05), 104 | ) 105 | -------------------------------------------------------------------------------- /src/algotrader/providers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/src/algotrader/providers/__init__.py -------------------------------------------------------------------------------- /src/algotrader/providers/ib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/src/algotrader/providers/ib/__init__.py -------------------------------------------------------------------------------- /src/algotrader/providers/ib/ib_interval.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | from algotrader.entities.timespan import TimeSpan 4 | 5 | 6 | def datetime_to_api_string(d: datetime) -> str: 7 | return d.strftime("%Y%m%d %H:%M:%S") 8 | 9 | 10 | def timespan_to_api_str(timespan: TimeSpan) -> str: 11 | if timespan == TimeSpan.Day: 12 | return "1 day" 13 | elif timespan == TimeSpan.Minute: 14 | return "1 min" 15 | else: 16 | raise Exception("data provider does not support this timespan") 17 | -------------------------------------------------------------------------------- /src/algotrader/providers/ib/query_subscription.py: -------------------------------------------------------------------------------- 1 | import threading 2 | from 
typing import List 3 | 4 | from algotrader.entities.candle import Candle 5 | from algotrader.entities.timespan import TimeSpan 6 | 7 | 8 | class QuerySubscription: 9 | def __init__(self, query_id: int, symbol: str, candle_timespan: TimeSpan) -> None: 10 | self.candle_timespan = candle_timespan 11 | self.symbol = symbol 12 | self.query_id = query_id 13 | self.done_event = threading.Event() 14 | self.candles: List[Candle] = [] 15 | self.is_error = False 16 | 17 | def push_candles(self, candles: List[Candle]): 18 | self.candles += candles 19 | 20 | def done(self, is_error: bool = False): 21 | self.is_error = is_error 22 | self.done_event.set() 23 | 24 | def result(self) -> List[Candle]: 25 | self.done_event.wait() 26 | 27 | if self.is_error: 28 | raise Exception("query failed. see logs.") 29 | 30 | return self.candles 31 | -------------------------------------------------------------------------------- /src/algotrader/serialization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/src/algotrader/serialization/__init__.py -------------------------------------------------------------------------------- /src/algotrader/serialization/store.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import importlib 4 | from typing import Dict, Optional, TypeVar, Generic 5 | 6 | # from typing import TYPE_CHECKING 7 | 8 | # if TYPE_CHECKING: 9 | from algotrader.entities.serializable import Deserializable 10 | 11 | T = TypeVar("T", bound=Deserializable) 12 | 13 | 14 | class DeserializationService(Generic[T]): 15 | @staticmethod 16 | def deserialize(data: Dict) -> Optional[T]: 17 | if data is None or data.get("__class__") is None: 18 | return None 19 | 20 | class_name = data.get("__class__", "") 21 | mod_name, cls_name = class_name.split(":") 22 | mod = importlib.import_module(mod_name) 23 | cls: Deserializable = getattr(mod, cls_name) 24 | return cls.deserialize(data) 25 | -------------------------------------------------------------------------------- /src/algotrader/storage/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/src/algotrader/storage/__init__.py -------------------------------------------------------------------------------- /src/algotrader/storage/inmemory_storage.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import List, Dict, Iterator 3 | 4 | from algotrader.entities.candle import Candle 5 | from algotrader.entities.timespan import TimeSpan 6 | from algotrader.storage.storage_provider import StorageProvider 7 | 8 | 9 | class InMemoryStorage(StorageProvider): 10 | def __init__(self) -> None: 11 | super().__init__() 12 | self.candles: Dict[str, List[Candle]] = {} 13 | 14 | def get_symbol_candles( 15 | self, symbol: str, time_span: TimeSpan, from_timestamp: datetime, to_timestamp: datetime, limit: int = 0 16 | ) -> List[Candle]: 17 | if symbol not in self.candles: 18 | return [] 19 | 20 | results = list( 21 | filter( 22 | lambda candle: candle.time_span == time_span and from_timestamp <= candle.timestamp <= to_timestamp, 23 | self.candles[symbol], 24 | ) 25 | ) 26 | 27 | if limit > 0: 28 | return results[:limit] 29 | 30 | return results 31 | 32 | def 
get_candles(self, time_span: TimeSpan, from_timestamp: datetime, to_timestamp: datetime) -> List[Candle]: 33 | def all_candles() -> Iterator[Candle]: 34 | for sym_candles in self.candles.values(): 35 | for c in sym_candles: 36 | yield c 37 | 38 | return list( 39 | filter( 40 | lambda candle: candle.time_span == time_span and from_timestamp <= candle.timestamp <= to_timestamp, 41 | all_candles(), 42 | ) 43 | ) 44 | 45 | def save(self, candle: Candle): 46 | if candle.symbol not in self.candles: 47 | self.candles[candle.symbol] = [] 48 | 49 | self.candles[candle.symbol].append(candle) 50 | self.candles[candle.symbol].sort(key=lambda c: c.timestamp) 51 | -------------------------------------------------------------------------------- /src/algotrader/storage/storage_provider.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from datetime import datetime 3 | from typing import List, Dict, Tuple 4 | 5 | from algotrader.entities.candle import Candle 6 | from algotrader.entities.serializable import Deserializable, Serializable 7 | from algotrader.entities.timespan import TimeSpan 8 | 9 | 10 | class StorageProvider(Serializable, Deserializable): 11 | @abstractmethod 12 | def save(self, candle: Candle): 13 | pass 14 | 15 | @abstractmethod 16 | def get_symbol_candles( 17 | self, symbol: str, time_span: TimeSpan, from_timestamp: datetime, to_timestamp: datetime, limit: int 18 | ) -> List[Candle]: 19 | pass 20 | 21 | @abstractmethod 22 | def get_candles(self, time_span: TimeSpan, from_timestamp: datetime, to_timestamp: datetime) -> List[Candle]: 23 | pass 24 | 25 | @abstractmethod 26 | def get_aggregated_history( 27 | self, 28 | from_timestamp: datetime, 29 | to_timestamp: datetime, 30 | groupby_fields: List[str], 31 | return_fields: List[str], 32 | min_count: int, 33 | min_return: float, 34 | ) -> Tuple[List[Dict[str, int]], List[Dict[str, int]]]: 35 | pass 36 | -------------------------------------------------------------------------------- /src/algotrader/trade/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/src/algotrader/trade/__init__.py -------------------------------------------------------------------------------- /src/algotrader/trade/signals_executor.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import List 3 | 4 | from algotrader.entities.candle import Candle 5 | from algotrader.entities.serializable import Deserializable, Serializable 6 | from algotrader.entities.strategy_signal import StrategySignal 7 | 8 | 9 | class SignalsExecutor(Serializable, Deserializable): 10 | @abstractmethod 11 | def execute(self, candle: Candle, signals: List[StrategySignal]): 12 | pass 13 | -------------------------------------------------------------------------------- /src/algotrader/trade/simple_sum_signals_executor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Dict 3 | 4 | from algotrader.entities.candle import Candle 5 | from algotrader.entities.strategy_signal import StrategySignal, SignalDirection 6 | from algotrader.trade.signals_executor import SignalsExecutor 7 | 8 | DEFAULT_ORDER_VALUE = 10000 9 | 10 | 11 | class SimpleSumSignalsExecutor(SignalsExecutor): 12 | def __init__(self) -> None: 13 | self.position: Dict[str, 
float] = {} 14 | self.cash = 0 15 | 16 | def _get_order_size(self, price: float) -> float: 17 | return DEFAULT_ORDER_VALUE / price 18 | 19 | def execute(self, candle: Candle, signals: List[StrategySignal]): 20 | # close when there is no signal 21 | if len(signals) == 0 and candle.symbol in self.position and self.position[candle.symbol] != 0: 22 | self.cash += candle.close * self.position[candle.symbol] 23 | self.position[candle.symbol] = 0 24 | 25 | for signal in signals: 26 | logging.info(f"Got {signal.direction} signal for {signal.symbol}. Signaling candle: {candle.model_dump()}") 27 | 28 | if signal.symbol not in self.position: 29 | self.position[signal.symbol] = 0 30 | 31 | # don't act if we already have a position 32 | if self.position[signal.symbol] != 0: 33 | continue 34 | 35 | order_size = self._get_order_size(candle.close) 36 | 37 | if signal.direction == SignalDirection.Long: 38 | self.position[signal.symbol] += order_size 39 | self.cash -= candle.close * order_size 40 | else: 41 | self.position[signal.symbol] -= order_size 42 | self.cash += candle.close * order_size 43 | 44 | non_zero_postitions = {k: v for k, v in self.position.items() if v > 0} 45 | if len(non_zero_postitions) > 0: 46 | logging.info(f"Position: {non_zero_postitions} | Cash: {self.cash}") 47 | -------------------------------------------------------------------------------- /src/algotrader/trade/stdout_signals_executor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | 4 | from algotrader.entities.candle import Candle 5 | from algotrader.entities.strategy_signal import StrategySignal 6 | from algotrader.trade.signals_executor import SignalsExecutor 7 | 8 | 9 | class StdoutSignalsExecutor(SignalsExecutor): 10 | def execute(self, candle: Candle, signals: List[StrategySignal]): 11 | for signal in signals: 12 | logging.info(f"Got {signal.direction} signal for {signal.symbol}. 
Signaling candle: {candle.model_dump()}") 13 | -------------------------------------------------------------------------------- /tests/configs/correlations.json: -------------------------------------------------------------------------------- 1 | { 2 | "groups": [ 3 | [ 4 | "X", 5 | "Y", 6 | "Z" 7 | ] 8 | ] 9 | } -------------------------------------------------------------------------------- /tests/fakes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/tests/fakes/__init__.py -------------------------------------------------------------------------------- /tests/fakes/pipeline_validators.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from algotrader.entities.candle import Candle 4 | from algotrader.entities.event import Event 5 | from algotrader.pipeline.processor import Processor 6 | from algotrader.pipeline.shared_context import SharedContext 7 | from algotrader.pipeline.terminator import Terminator 8 | 9 | 10 | class ValidationProcessor(Processor): 11 | def __init__( 12 | self, 13 | process_callback: Callable[[SharedContext, Candle], None], 14 | event_callback: Callable[[SharedContext, Event], None] = None, 15 | ) -> None: 16 | super().__init__(None) 17 | self.process_callback = process_callback 18 | self.event_callback = event_callback 19 | 20 | def process(self, context: SharedContext, candle: Candle): 21 | self.process_callback(context, candle) 22 | 23 | def event(self, context: SharedContext, event: Event): 24 | if self.event_callback: 25 | self.event_callback(context, event) 26 | 27 | 28 | class TerminatorValidator(Terminator): 29 | def __init__(self, callback: Callable[[SharedContext], None]) -> None: 30 | self.callback = callback 31 | 32 | def terminate(self, context: SharedContext): 33 | self.callback(context) 34 | -------------------------------------------------------------------------------- /tests/fakes/source.py: -------------------------------------------------------------------------------- 1 | from typing import List, Iterator 2 | 3 | from algotrader.entities.candle import Candle 4 | from algotrader.pipeline.source import Source 5 | 6 | 7 | class FakeSource(Source): 8 | def __init__(self, candles: List[Candle]) -> None: 9 | super().__init__() 10 | self.candles = candles 11 | self.candles.sort(key=lambda c: c.timestamp) 12 | 13 | def read(self) -> Iterator[Candle]: 14 | for c in self.candles: 15 | yield c 16 | -------------------------------------------------------------------------------- /tests/fakes/strategy_executor.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List 2 | 3 | from algotrader.entities.candle import Candle 4 | from algotrader.entities.strategy_signal import StrategySignal 5 | from algotrader.trade.signals_executor import SignalsExecutor 6 | 7 | ExecuterCallback = Callable[[List[StrategySignal]], None] 8 | 9 | 10 | class FakeSignalsExecutor(SignalsExecutor): 11 | def __init__(self, callback: ExecuterCallback) -> None: 12 | super().__init__() 13 | self.callback = callback 14 | 15 | def execute(self, candle: Candle, signals: List[StrategySignal]): 16 | self.callback(signals) 17 | -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/tests/integration/__init__.py -------------------------------------------------------------------------------- /tests/integration/test_binance_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import threading 3 | from datetime import datetime, timedelta 4 | from unittest import TestCase 5 | 6 | from dotenv import load_dotenv 7 | 8 | from algotrader.entities.candle import Candle 9 | from algotrader.entities.timespan import TimeSpan 10 | from algotrader.providers.binance import BinanceProvider 11 | 12 | load_dotenv() 13 | 14 | 15 | class TestBinanceMarketProvider(TestCase): 16 | SYMBOL = "BTCUSDT" 17 | BASE_TIME = datetime.fromtimestamp(1671183359) 18 | API_KEY = os.environ.get("BINANCE_API_KEY") 19 | API_SECRET = os.environ.get("BINANCE_API_SECRET") 20 | 21 | def test_get_account(self): 22 | provider = BinanceProvider(self.API_KEY, self.API_SECRET, False, testnet=True) 23 | provider.account() 24 | 25 | def test_get_symbol_history(self): 26 | provider = BinanceProvider(self.API_KEY, self.API_SECRET, False) 27 | 28 | from_time = self.BASE_TIME - timedelta(days=50) 29 | to_time = self.BASE_TIME 30 | candles = provider.get_symbol_history(self.SYMBOL, TimeSpan.Day, from_time, to_time) 31 | 32 | self.assertEqual(len(candles), 50) 33 | for candle in candles: 34 | self.assertTrue(from_time <= candle.timestamp <= to_time) 35 | self._assert_candles_values(candle) 36 | 37 | def test_get_kline_stream(self): 38 | provider = BinanceProvider(self.API_KEY, self.API_SECRET, True) 39 | event: threading.Event = threading.Event() 40 | 41 | def handler(candle: Candle): 42 | self._assert_candles_values(candle) 43 | event.set() 44 | 45 | provider.start_kline_socket(self.SYMBOL, TimeSpan.Second, handler) 46 | event.wait() 47 | 48 | def _assert_candles_values(self, candle: Candle): 49 | self.assertEqual(candle.symbol, self.SYMBOL) 50 | self.assertTrue(candle.high > candle.low) 51 | self.assertTrue(candle.volume > 0) 52 | -------------------------------------------------------------------------------- /tests/integration/test_ib_provider.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from unittest import TestCase 3 | 4 | from algotrader.entities.timespan import TimeSpan 5 | from algotrader.market.async_market_provider import AsyncQueryResult 6 | from algotrader.market.ib_market import IBMarketProvider 7 | from algotrader.providers.ib.interactive_brokers_connector import InteractiveBrokersConnector 8 | 9 | 10 | class TestIBMarketProvider(TestCase): 11 | def setUp(self) -> None: 12 | super().setUp() 13 | self.ib_connector = InteractiveBrokersConnector() 14 | 15 | def tearDown(self) -> None: 16 | super().tearDown() 17 | self.ib_connector.kill() 18 | 19 | def test_daily_history(self): 20 | ib_provider = IBMarketProvider(self.ib_connector) 21 | from_time = datetime.now() - timedelta(days=50) 22 | to_time = datetime.now() - timedelta(days=30) 23 | 24 | async_result: AsyncQueryResult = ib_provider.request_symbol_history("AAPL", TimeSpan.Day, from_time, to_time) 25 | candles = async_result.result() 26 | self.assertTrue(len(candles) > 10) 27 | self.assertTrue(candles[0].timestamp < candles[-1].timestamp) 28 | 29 | def test_current_day_history(self): 30 | ib_provider = IBMarketProvider(self.ib_connector) 31 | 32 | yesterday = (datetime.now() - timedelta(days=1)).date() 33 | yesterday = 
datetime(yesterday.year, yesterday.month, yesterday.day) 34 | async_result: AsyncQueryResult = ib_provider.request_symbol_history( 35 | "AAPL", TimeSpan.Day, yesterday, datetime.now() 36 | ) 37 | candles = async_result.result() 38 | self.assertEqual(1, len(candles)) 39 | self.assertEqual(yesterday.date(), candles[0].timestamp.date()) 40 | 41 | def test_yearly_history(self): 42 | ib_provider = IBMarketProvider(self.ib_connector) 43 | from_time = datetime.now() - timedelta(days=500) 44 | to_time = datetime.now() - timedelta(days=100) 45 | 46 | async_result: AsyncQueryResult = ib_provider.request_symbol_history("AAPL", TimeSpan.Day, from_time, to_time) 47 | candles = async_result.result() 48 | self.assertTrue(len(candles) > 10) 49 | self.assertTrue(candles[0].timestamp < candles[-1].timestamp) 50 | self.assertTrue(candles[0].timestamp.year < candles[-1].timestamp.year) 51 | 52 | self.assertEqual(candles[0].timestamp.year, from_time.year) 53 | self.assertEqual(candles[0].timestamp.month, from_time.month) 54 | 55 | self.assertEqual(candles[-1].timestamp.year, to_time.year) 56 | self.assertEqual(candles[-1].timestamp.month, to_time.month) 57 | -------------------------------------------------------------------------------- /tests/integration/test_ib_source.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from unittest import TestCase 3 | 4 | from algotrader.entities.timespan import TimeSpan 5 | from algotrader.pipeline.sources.ib_history import IBHistorySource 6 | from algotrader.providers.ib.interactive_brokers_connector import InteractiveBrokersConnector 7 | 8 | 9 | class TestIBSource(TestCase): 10 | def setUp(self) -> None: 11 | super().setUp() 12 | self.ib_connector = InteractiveBrokersConnector() 13 | 14 | def tearDown(self) -> None: 15 | super().tearDown() 16 | self.ib_connector.kill() 17 | 18 | def test(self): 19 | symbol = "AAPL" 20 | from_time = datetime.now() - timedelta(days=30) 21 | source = IBHistorySource(self.ib_connector, [symbol], TimeSpan.Day, from_time) 22 | 23 | candles = list(source.read()) 24 | self.assertTrue(len(candles) > 10) 25 | 26 | for candle in candles: 27 | self.assertEqual(symbol, candle.symbol) 28 | self.assertEqual(TimeSpan.Day, candle.time_span) 29 | self.assertTrue(candle.timestamp < datetime.now()) 30 | -------------------------------------------------------------------------------- /tests/integration/test_yahoo_provider.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from unittest import TestCase 3 | 4 | from algotrader.entities.timespan import TimeSpan 5 | from algotrader.market.yahoofinance.history_provider import YahooFinanceHistoryProvider 6 | 7 | 8 | class TestYahooMarketProvider(TestCase): 9 | def test_get_symbol_history(self): 10 | from_time = datetime.now() - timedelta(days=50) 11 | to_time = datetime.now() 12 | provider = YahooFinanceHistoryProvider() 13 | result = provider.get_symbol_history("AAPL", TimeSpan.Day, TimeSpan.Day, from_time, to_time) 14 | self.assertTrue(len(result) > 10) 15 | self.assertIsNotNone(result) 16 | for candle in result: 17 | self.assertEqual("AAPL", candle.symbol) 18 | self.assertTrue(0 < candle.open) 19 | self.assertTrue(0 < candle.close) 20 | self.assertTrue(0 < candle.volume) 21 | self.assertTrue(to_time.date() >= candle.timestamp.date() >= from_time.date()) 22 | self.assertTrue(0 < candle.low < candle.high) 23 | 
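Usage note (not part of the repository): the integration tests above and below exercise the Yahoo Finance provider and source directly. As a small illustrative sketch, the same source can also be drained by hand into any processor, for example the FileSinkProcessor shown earlier; this bypasses Pipeline/PipelineRunner, and the output path and date range are arbitrary.

from datetime import datetime, timedelta

from algotrader.entities.timespan import TimeSpan
from algotrader.pipeline.processors.file_sink import FileSinkProcessor
from algotrader.pipeline.shared_context import SharedContext
from algotrader.pipeline.sources.yahoo_finance_history import YahooFinanceHistorySource

source = YahooFinanceHistorySource(["AAPL", "MSFT"], TimeSpan.Day,
                                   datetime.now() - timedelta(days=30), datetime.now())
sink = FileSinkProcessor("/tmp/candles.jsonl")  # arbitrary output path

context = SharedContext()
for candle in source.read():
    sink.process(context, candle)  # appends the serialized candle as one line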
-------------------------------------------------------------------------------- /tests/integration/test_yahoo_source.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from typing import List 3 | from unittest import TestCase 4 | 5 | from algotrader.entities.candle import Candle 6 | from algotrader.entities.timespan import TimeSpan 7 | from algotrader.market.yahoofinance.history_provider import YahooFinanceHistoryProvider 8 | from algotrader.pipeline.sources.yahoo_finance_history import YahooFinanceHistorySource 9 | 10 | 11 | class TestYahooMarketSource(TestCase): 12 | provider = YahooFinanceHistoryProvider() 13 | symbols = ["AAPL", "MSFT"] 14 | to_time = datetime.fromtimestamp(1669145312) 15 | from_time = to_time - timedelta(days=50) 16 | 17 | def test_quick_source(self): 18 | source = YahooFinanceHistorySource(self.symbols, TimeSpan.Day, self.from_time, self.to_time) 19 | candles = list(source.read()) 20 | self._assert_sanity_response(candles) 21 | 22 | def test_sorted_source(self): 23 | source = YahooFinanceHistorySource(self.symbols, TimeSpan.Day, self.from_time, self.to_time, sort_all=True) 24 | candles = list(source.read()) 25 | self._assert_sanity_response(candles) 26 | 27 | for i in range(1, len(candles)): 28 | self.assertTrue(candles[i].timestamp >= candles[i - 1].timestamp) 29 | 30 | def _assert_sanity_response(self, candles: List[Candle]): 31 | self.assertEqual(len(candles) % 2, 0) 32 | 33 | for candle in candles: 34 | self.assertIn(candle.symbol, self.symbols) 35 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | 4 | from algotrader.entities.candle import Candle 5 | from algotrader.entities.timespan import TimeSpan 6 | 7 | TEST_SYMBOL = "X" 8 | 9 | 10 | def generate_candle_with_symbol(symbol: str, time_span: TimeSpan, timestamp: datetime) -> Candle: 11 | return Candle( 12 | symbol=symbol, time_span=time_span, timestamp=timestamp, open=0.0, close=0.0, high=0.0, low=0.0, volume=0.0 13 | ) 14 | 15 | 16 | def generate_candle(time_span: TimeSpan, timestamp: datetime) -> Candle: 17 | return generate_candle_with_symbol(TEST_SYMBOL, time_span, timestamp) 18 | 19 | 20 | def generate_candle_with_price(time_span: TimeSpan, timestamp: datetime, price: float) -> Candle: 21 | candle = generate_candle(time_span, timestamp) 22 | candle.open = candle.close = candle.high = candle.low = candle.volume = price 23 | return candle 24 | 25 | 26 | def generate_candle_with_price_and_symbol( 27 | symbol: str, time_span: TimeSpan, timestamp: datetime, price: float 28 | ) -> Candle: 29 | candle = generate_candle_with_symbol(symbol, time_span, timestamp) 30 | candle.open = candle.close = candle.high = candle.low = candle.volume = price 31 | return candle 32 | -------------------------------------------------------------------------------- /tests/unit/strategies/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idanya/algo-trader/742589a488f02bbffd6bc576c08357622084b7f9/tests/unit/strategies/__init__.py -------------------------------------------------------------------------------- /tests/unit/strategies/test_simple_sma.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import List, Tuple 
3 | from unittest import TestCase 4 | 5 | from algotrader.entities.candle import Candle 6 | from algotrader.entities.strategy_signal import StrategySignal, SignalDirection 7 | from algotrader.entities.timespan import TimeSpan 8 | from fakes.strategy_executor import FakeSignalsExecutor 9 | from algotrader.pipeline.processors.candle_cache import CandleCache 10 | from algotrader.pipeline.processors.strategy import StrategyProcessor 11 | from algotrader.pipeline.processors.technicals import INDICATORS_ATTACHMENT_KEY 12 | from algotrader.entities.attachments.technicals import Indicators 13 | from algotrader.pipeline.shared_context import SharedContext 14 | from algotrader.pipeline.strategies.simple_sma import SimpleSMA 15 | from unit import generate_candle, TEST_SYMBOL 16 | 17 | 18 | class TestSimpleSMAStrategy(TestCase): 19 | def test_long(self): 20 | def _check(signals: List[StrategySignal]): 21 | self.assertEqual(1, len(signals)) 22 | self.assertEqual(TEST_SYMBOL, signals[0].symbol) 23 | self.assertEqual(SignalDirection.Long, signals[0].direction) 24 | 25 | prev_candle, current_candle = self._get_candles() 26 | 27 | context = SharedContext() 28 | cache_processor = CandleCache(None) 29 | cache_processor.process(context, prev_candle) 30 | 31 | processor = StrategyProcessor([SimpleSMA()], FakeSignalsExecutor(_check), None) 32 | processor.process(context, current_candle) 33 | 34 | def test_short(self): 35 | def _check(signals: List[StrategySignal]): 36 | self.assertEqual(1, len(signals)) 37 | self.assertEqual(TEST_SYMBOL, signals[0].symbol) 38 | self.assertEqual(SignalDirection.Short, signals[0].direction) 39 | 40 | current_candle, prev_candle = self._get_candles() 41 | 42 | context = SharedContext() 43 | cache_processor = CandleCache(None) 44 | cache_processor.process(context, prev_candle) 45 | 46 | processor = StrategyProcessor([SimpleSMA()], FakeSignalsExecutor(_check), None) 47 | processor.process(context, current_candle) 48 | 49 | def _get_candles(self) -> Tuple[Candle, Candle]: 50 | prev_candle = generate_candle(TimeSpan.Day, datetime.now()) 51 | current_candle = generate_candle(TimeSpan.Day, datetime.now()) 52 | 53 | prev_indicators = Indicators() 54 | prev_indicators.set("sma5", 5) 55 | prev_indicators.set("sma20", 6) 56 | 57 | current_indicators = Indicators() 58 | current_indicators.set("sma5", 6) 59 | current_indicators.set("sma20", 5) 60 | 61 | prev_candle.add_attachment(INDICATORS_ATTACHMENT_KEY, prev_indicators) 62 | current_candle.add_attachment(INDICATORS_ATTACHMENT_KEY, current_indicators) 63 | return prev_candle, current_candle 64 | -------------------------------------------------------------------------------- /tests/unit/test_asset_correlation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from datetime import datetime, timedelta 4 | from typing import List 5 | from unittest import TestCase 6 | 7 | from algotrader.calc.calculations import TechnicalCalculation 8 | from algotrader.entities.candle import Candle 9 | from algotrader.entities.timespan import TimeSpan 10 | from algotrader.pipeline.configs.indicator_config import IndicatorConfig 11 | from algotrader.pipeline.configs.technical_processor_config import TechnicalsProcessorConfig 12 | from algotrader.pipeline.pipeline import Pipeline 13 | from algotrader.pipeline.processors.assets_correlation import ( 14 | CORRELATIONS_ATTACHMENT_KEY, 15 | AssetCorrelationProcessor, 16 | ) 17 | from algotrader.entities.attachments.assets_correlation import 
AssetCorrelation 18 | from algotrader.pipeline.processors.candle_cache import CandleCache 19 | from algotrader.pipeline.processors.technicals import TechnicalsProcessor 20 | from algotrader.pipeline.processors.timespan_change import TimeSpanChangeProcessor 21 | from algotrader.pipeline.runner import PipelineRunner 22 | from algotrader.pipeline.shared_context import SharedContext 23 | from fakes.pipeline_validators import ValidationProcessor 24 | from fakes.source import FakeSource 25 | from unit import generate_candle_with_price_and_symbol 26 | 27 | 28 | class TestAssetCorrelationProcessor(TestCase): 29 | def setUp(self) -> None: 30 | super().setUp() 31 | x = [ 32 | generate_candle_with_price_and_symbol( 33 | "X", TimeSpan.Day, datetime.now() - timedelta(days=c), c + random.randint(1, 10) 34 | ) 35 | for c in range(1, 49) 36 | ] 37 | y = [ 38 | generate_candle_with_price_and_symbol( 39 | "Y", TimeSpan.Day, datetime.now() - timedelta(days=c), c + random.randint(1, 10) 40 | ) 41 | for c in range(1, 49) 42 | ] 43 | z = [ 44 | generate_candle_with_price_and_symbol( 45 | "Z", TimeSpan.Day, datetime.now() - timedelta(days=c), c + random.randint(1, 10) 46 | ) 47 | for c in range(1, 49) 48 | ] 49 | 50 | merged: List[Candle] = [] 51 | for i in range(len(x)): 52 | merged.append(x[i]) 53 | merged.append(y[i]) 54 | merged.append(z[i]) 55 | 56 | self.source = FakeSource(merged) 57 | 58 | def test_correlation(self): 59 | def _check(context: SharedContext, candle: Candle): 60 | self.assertIsNotNone(context) 61 | context.put_kv_data("check_count", context.get_kv_data("check_count", 0) + 1) 62 | 63 | check_count = context.get_kv_data("check_count", 0) 64 | if check_count > 20: 65 | cache_reader = CandleCache.context_reader(context) 66 | latest_candle = cache_reader.get_symbol_candles(candle.symbol)[-2] 67 | asset_correlation: AssetCorrelation = latest_candle.get_attachment(CORRELATIONS_ATTACHMENT_KEY) 68 | if candle.symbol == "X": 69 | self.assertFalse(asset_correlation.has("X")) 70 | self.assertTrue(asset_correlation.has("Y")) 71 | self.assertTrue(asset_correlation.has("Z")) 72 | else: 73 | self.assertTrue(asset_correlation.has("X")) 74 | 75 | validator = ValidationProcessor(_check) 76 | cache_processor = CandleCache(validator) 77 | correlations_file_path = os.path.join( 78 | os.path.dirname(os.path.abspath(__file__)), "../configs/correlations.json" 79 | ) 80 | asset_correlation = AssetCorrelationProcessor(correlations_file_path, cache_processor) 81 | timespan_change_processor = TimeSpanChangeProcessor(TimeSpan.Day, asset_correlation) 82 | 83 | config = TechnicalsProcessorConfig([ 84 | IndicatorConfig("sma5", TechnicalCalculation.SMA, [5]), 85 | ]) 86 | 87 | technicals = TechnicalsProcessor(config, timespan_change_processor) 88 | PipelineRunner(Pipeline(self.source, technicals)).run() 89 | -------------------------------------------------------------------------------- /tests/unit/test_assets_provider.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | from algotrader.assets.assets_provider import AssetsProvider 4 | 5 | 6 | class TestAssestProvider(TestCase): 7 | def test_get_sp500(self): 8 | symbols = AssetsProvider.get_sp500_symbols() 9 | self.assertEqual(495, len(symbols)) 10 | self.assertTrue("AAPL" in symbols) 11 | self.assertTrue("AMD" in symbols) 12 | self.assertTrue("WYNN" in symbols) 13 | -------------------------------------------------------------------------------- /tests/unit/test_async_query_result.py: 
-------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from unittest import TestCase 3 | 4 | from algotrader.entities.timespan import TimeSpan 5 | from algotrader.market.async_query_result import AsyncQueryResult 6 | from algotrader.providers.ib.query_subscription import QuerySubscription 7 | from unit import generate_candle 8 | 9 | 10 | class TestAsyncQueryResult(TestCase): 11 | def setUp(self) -> None: 12 | super().setUp() 13 | self.from_ts = datetime.now() - timedelta(days=10) 14 | self.to_ts = datetime.now() 15 | self.result = AsyncQueryResult(self.from_ts, self.to_ts) 16 | 17 | def test_candles_out_of_range(self): 18 | subscription = QuerySubscription(1, "X", TimeSpan.Day) 19 | self.result.attach_query_subscription(subscription) 20 | subscription.push_candles([ 21 | generate_candle(TimeSpan.Day, self.from_ts), 22 | generate_candle(TimeSpan.Day, self.to_ts), 23 | generate_candle(TimeSpan.Day, self.from_ts - timedelta(days=1)), 24 | ]) 25 | 26 | subscription.done() 27 | 28 | self.assertEqual(2, len(self.result.result())) 29 | 30 | def test_multiple_subscriptions(self): 31 | subscription1 = QuerySubscription(1, "X", TimeSpan.Day) 32 | self.result.attach_query_subscription(subscription1) 33 | 34 | subscription2 = QuerySubscription(2, "X", TimeSpan.Day) 35 | self.result.attach_query_subscription(subscription2) 36 | 37 | subscription1.push_candles([ 38 | generate_candle(TimeSpan.Day, self.from_ts), 39 | generate_candle(TimeSpan.Day, self.to_ts), 40 | generate_candle(TimeSpan.Day, self.from_ts - timedelta(days=1)), 41 | ]) 42 | 43 | subscription2.push_candles([ 44 | generate_candle(TimeSpan.Day, self.from_ts), 45 | generate_candle(TimeSpan.Day, self.to_ts), 46 | generate_candle(TimeSpan.Day, self.from_ts - timedelta(days=1)), 47 | ]) 48 | 49 | subscription1.done() 50 | subscription2.done() 51 | 52 | self.assertEqual(4, len(self.result.result())) 53 | 54 | def test_multiple_subscriptions_with_error(self): 55 | subscription1 = QuerySubscription(1, "X", TimeSpan.Day) 56 | self.result.attach_query_subscription(subscription1) 57 | 58 | subscription2 = QuerySubscription(2, "X", TimeSpan.Day) 59 | self.result.attach_query_subscription(subscription2) 60 | 61 | subscription1.push_candles([ 62 | generate_candle(TimeSpan.Day, self.from_ts), 63 | generate_candle(TimeSpan.Day, self.to_ts), 64 | generate_candle(TimeSpan.Day, self.from_ts - timedelta(days=1)), 65 | ]) 66 | 67 | subscription2.push_candles([ 68 | generate_candle(TimeSpan.Day, self.from_ts), 69 | generate_candle(TimeSpan.Day, self.to_ts), 70 | generate_candle(TimeSpan.Day, self.from_ts - timedelta(days=1)), 71 | ]) 72 | 73 | subscription1.done() 74 | subscription2.done(True) 75 | 76 | with self.assertRaises(Exception): 77 | self.result.result() 78 | -------------------------------------------------------------------------------- /tests/unit/test_candle_cache.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from unittest import TestCase 3 | 4 | from algotrader.entities.candle import Candle 5 | from algotrader.entities.timespan import TimeSpan 6 | from fakes.source import FakeSource 7 | from fakes.pipeline_validators import ValidationProcessor 8 | from algotrader.pipeline.pipeline import Pipeline 9 | from algotrader.pipeline.processors.candle_cache import CandleCache 10 | from algotrader.pipeline.runner import PipelineRunner 11 | from algotrader.pipeline.shared_context import SharedContext 12 | 
from unit import generate_candle 13 | 14 | 15 | class TestCandleCacheProcessor(TestCase): 16 | def setUp(self) -> None: 17 | super().setUp() 18 | self.test_candle = generate_candle(TimeSpan.Day, datetime.now()) 19 | self.source = FakeSource([self.test_candle]) 20 | 21 | def test(self): 22 | def _check(context: SharedContext, candle: Candle): 23 | self.assertIsNotNone(context) 24 | 25 | cache_reader = CandleCache.context_reader(context) 26 | cached_candles = cache_reader.get_symbol_candles(candle.symbol) 27 | self.assertEqual(self.test_candle.symbol, cached_candles[0].symbol) 28 | 29 | validator = ValidationProcessor(_check) 30 | processor = CandleCache(validator) 31 | PipelineRunner(Pipeline(self.source, processor)).run() 32 | -------------------------------------------------------------------------------- /tests/unit/test_filesink_processor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import tempfile 5 | from datetime import datetime 6 | from unittest import TestCase 7 | 8 | from algotrader.entities.candle import Candle 9 | from algotrader.entities.timespan import TimeSpan 10 | from fakes.pipeline_validators import TerminatorValidator 11 | from fakes.source import FakeSource 12 | from algotrader.pipeline.pipeline import Pipeline 13 | from algotrader.pipeline.processors.file_sink import FileSinkProcessor 14 | from algotrader.pipeline.runner import PipelineRunner 15 | from algotrader.pipeline.shared_context import SharedContext 16 | from unit import generate_candle_with_price 17 | 18 | 19 | class TestFileSinkProcessor(TestCase): 20 | def setUp(self) -> None: 21 | super().setUp() 22 | self.source = FakeSource([ 23 | generate_candle_with_price(TimeSpan.Day, datetime.now(), random.randint(0, c)) for c in range(1, 50) 24 | ]) 25 | 26 | def test(self): 27 | temp_file = tempfile.NamedTemporaryFile(delete=False) 28 | 29 | def _check(context: SharedContext): 30 | self.assertIsNotNone(context) 31 | lines = temp_file.readlines() 32 | self.assertEqual(49, len(lines)) 33 | for line in lines: 34 | candle = Candle.model_validate_json(json.loads(line)) 35 | self.assertEqual(TimeSpan.Day, candle.time_span) 36 | self.assertEqual(datetime.now().day, candle.timestamp.day) 37 | 38 | validator = TerminatorValidator(_check) 39 | 40 | processor = FileSinkProcessor(temp_file.name) 41 | PipelineRunner(Pipeline(self.source, processor, validator)).run() 42 | 43 | temp_file.close() 44 | os.unlink(temp_file.name) 45 | -------------------------------------------------------------------------------- /tests/unit/test_inmemory_storage.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from datetime import datetime 3 | from typing import List 4 | 5 | from algotrader.entities.candle import Candle 6 | from algotrader.entities.timespan import TimeSpan 7 | from algotrader.storage.inmemory_storage import InMemoryStorage 8 | from unit import generate_candle, TEST_SYMBOL 9 | 10 | 11 | class TestInMemoryStorage(unittest.TestCase): 12 | def setUp(self) -> None: 13 | super().setUp() 14 | self.inmemory_storage = InMemoryStorage() 15 | 16 | def test_save_single_candle(self): 17 | minute_candle = generate_candle(TimeSpan.Minute, datetime.now()) 18 | 19 | self.inmemory_storage.save(minute_candle) 20 | 21 | candles: List[Candle] = self.inmemory_storage.get_symbol_candles( 22 | symbol=TEST_SYMBOL, 23 | time_span=TimeSpan.Minute, 24 | from_timestamp=minute_candle.timestamp, 25 | 
to_timestamp=minute_candle.timestamp, 26 | ) 27 | 28 | self.assertEqual(1, len(candles)) 29 | self.assertEqual(TEST_SYMBOL, candles[0].symbol) 30 | self.assertEqual(TimeSpan.Minute, candles[0].time_span) 31 | self.assertEqual(minute_candle.timestamp, candles[0].timestamp) 32 | 33 | def test_save_different_timespans_candle(self): 34 | minute_candle = generate_candle(TimeSpan.Minute, datetime.now()) 35 | self.inmemory_storage.save(minute_candle) 36 | 37 | day_candle = generate_candle(TimeSpan.Day, minute_candle.timestamp) 38 | self.inmemory_storage.save(day_candle) 39 | 40 | candles: List[Candle] = self.inmemory_storage.get_symbol_candles( 41 | symbol=TEST_SYMBOL, 42 | time_span=TimeSpan.Minute, 43 | from_timestamp=minute_candle.timestamp, 44 | to_timestamp=minute_candle.timestamp, 45 | ) 46 | 47 | self.assertEqual(1, len(candles)) 48 | self.assertEqual(TimeSpan.Minute, candles[0].time_span) 49 | self.assertEqual(minute_candle.timestamp, candles[0].timestamp) 50 | -------------------------------------------------------------------------------- /tests/unit/test_mongo_source.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from unittest import TestCase 3 | 4 | import mongomock 5 | 6 | from algotrader.entities.timespan import TimeSpan 7 | from algotrader.pipeline.sources.mongodb_source import MongoDBSource 8 | from algotrader.storage.mongodb_storage import MongoDBStorage 9 | from unit import generate_candle, TEST_SYMBOL 10 | 11 | 12 | class TestMongoSource(TestCase): 13 | @mongomock.patch(servers=(("localhost", 27017),)) 14 | def setUp(self) -> None: 15 | super().setUp() 16 | self.mongo_storage = MongoDBStorage() 17 | self.mongo_storage.__drop_collections__() 18 | 19 | def test(self): 20 | for i in range(10): 21 | self.mongo_storage.save(generate_candle(TimeSpan.Day, datetime.now() - timedelta(minutes=i))) 22 | 23 | from_time = datetime.now() - timedelta(days=1) 24 | to_time = datetime.now() 25 | source = MongoDBSource(self.mongo_storage, [TEST_SYMBOL], TimeSpan.Day, from_time, to_time) 26 | 27 | candles = list(source.read()) 28 | self.assertEqual(len(candles), 10) 29 | 30 | for candle in candles: 31 | self.assertEqual(TEST_SYMBOL, candle.symbol) 32 | self.assertEqual(TimeSpan.Day, candle.time_span) 33 | self.assertTrue(candle.timestamp < datetime.now()) 34 | -------------------------------------------------------------------------------- /tests/unit/test_mongodb_sink_processor.py: -------------------------------------------------------------------------------- 1 | import random 2 | from datetime import datetime, timedelta 3 | from unittest import TestCase 4 | 5 | import mongomock 6 | 7 | from algotrader.entities.timespan import TimeSpan 8 | from fakes.pipeline_validators import TerminatorValidator 9 | from fakes.source import FakeSource 10 | from algotrader.pipeline.pipeline import Pipeline 11 | from algotrader.pipeline.processors.storage_provider_sink import StorageSinkProcessor 12 | from algotrader.pipeline.runner import PipelineRunner 13 | from algotrader.pipeline.shared_context import SharedContext 14 | from algotrader.storage.mongodb_storage import MongoDBStorage 15 | from unit import generate_candle_with_price, TEST_SYMBOL 16 | 17 | 18 | class TestMongoDBSinkProcessor(TestCase): 19 | def setUp(self) -> None: 20 | super().setUp() 21 | self.source = FakeSource([ 22 | generate_candle_with_price(TimeSpan.Day, datetime.now() - timedelta(minutes=c), random.randint(0, c)) 23 | for c in range(1, 50) 24 | ]) 25 | 
26 | @mongomock.patch(servers=(("localhost", 27017),)) 27 | def test(self): 28 | mogodb_storage = MongoDBStorage() 29 | mogodb_storage.__drop_collections__() 30 | 31 | def _check(context: SharedContext): 32 | self.assertIsNotNone(context) 33 | candles = mogodb_storage.get_symbol_candles( 34 | TEST_SYMBOL, TimeSpan.Day, datetime.now() - timedelta(days=1), datetime.now() 35 | ) 36 | self.assertEqual(49, len(candles)) 37 | 38 | validator = TerminatorValidator(_check) 39 | processor = StorageSinkProcessor(mogodb_storage) 40 | PipelineRunner(Pipeline(self.source, processor, validator)).run() 41 | -------------------------------------------------------------------------------- /tests/unit/test_mongodb_storage.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from typing import List 3 | from unittest import TestCase 4 | 5 | import mongomock 6 | 7 | from algotrader.entities.candle import Candle 8 | from algotrader.entities.timespan import TimeSpan 9 | from algotrader.storage.mongodb_storage import MongoDBStorage 10 | from unit import generate_candle, TEST_SYMBOL 11 | 12 | 13 | class TestMongoDBStorage(TestCase): 14 | @mongomock.patch(servers=(("localhost", 27017),)) 15 | def setUp(self) -> None: 16 | super().setUp() 17 | self.mogodb_storage = MongoDBStorage() 18 | self.mogodb_storage.__drop_collections__() 19 | 20 | def test_save_single_candle(self): 21 | minute_candle = generate_candle(TimeSpan.Minute, datetime.now().replace(microsecond=0)) 22 | 23 | self.mogodb_storage.save(minute_candle) 24 | 25 | candles: List[Candle] = self.mogodb_storage.get_symbol_candles( 26 | symbol=TEST_SYMBOL, 27 | time_span=TimeSpan.Minute, 28 | from_timestamp=minute_candle.timestamp, 29 | to_timestamp=minute_candle.timestamp, 30 | ) 31 | 32 | self.assertEqual(1, len(candles)) 33 | self.assertEqual(TEST_SYMBOL, candles[0].symbol) 34 | self.assertEqual(TimeSpan.Minute, candles[0].time_span) 35 | self.assertEqual(minute_candle.timestamp, candles[0].timestamp) 36 | 37 | def test_save_different_timespans_candle(self): 38 | minute_candle = generate_candle(TimeSpan.Minute, datetime.now().replace(microsecond=0)) 39 | self.mogodb_storage.save(minute_candle) 40 | 41 | day_candle = generate_candle(TimeSpan.Day, minute_candle.timestamp.replace(microsecond=0)) 42 | self.mogodb_storage.save(day_candle) 43 | 44 | candles: List[Candle] = self.mogodb_storage.get_symbol_candles( 45 | symbol=TEST_SYMBOL, 46 | time_span=TimeSpan.Minute, 47 | from_timestamp=minute_candle.timestamp, 48 | to_timestamp=minute_candle.timestamp, 49 | ) 50 | 51 | self.assertEqual(1, len(candles)) 52 | self.assertEqual(TimeSpan.Minute, candles[0].time_span) 53 | self.assertEqual(minute_candle.timestamp, candles[0].timestamp) 54 | 55 | def test_sorted_results(self): 56 | minute_candle = generate_candle(TimeSpan.Minute, datetime.now().replace(microsecond=0)) 57 | next_minute_candle = generate_candle( 58 | TimeSpan.Minute, (datetime.now() + timedelta(minutes=1)).replace(microsecond=0) 59 | ) 60 | 61 | self.mogodb_storage.save(next_minute_candle) 62 | self.mogodb_storage.save(minute_candle) 63 | 64 | candles: List[Candle] = self.mogodb_storage.get_symbol_candles( 65 | symbol=TEST_SYMBOL, 66 | time_span=TimeSpan.Minute, 67 | from_timestamp=minute_candle.timestamp, 68 | to_timestamp=next_minute_candle.timestamp, 69 | ) 70 | 71 | self.assertEqual(2, len(candles)) 72 | self.assertEqual(candles[0].timestamp, minute_candle.timestamp) 73 | self.assertEqual(candles[1].timestamp, 
next_minute_candle.timestamp) 74 | -------------------------------------------------------------------------------- /tests/unit/test_multiple_pipelines.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from unittest import TestCase 3 | 4 | from algotrader.entities.timespan import TimeSpan 5 | from fakes.pipeline_validators import TerminatorValidator 6 | from fakes.source import FakeSource 7 | from algotrader.pipeline.pipeline import Pipeline 8 | from algotrader.pipeline.processor import Processor 9 | from algotrader.pipeline.runner import PipelineRunner 10 | from algotrader.pipeline.shared_context import SharedContext 11 | from unit import generate_candle_with_price 12 | 13 | 14 | class TestMultiplePipelines(TestCase): 15 | def setUp(self) -> None: 16 | super().setUp() 17 | 18 | def test_multiple_pipelines(self): 19 | def _check_pipeline_one(context: SharedContext): 20 | self.assertIsNotNone(context) 21 | context.put_kv_data("check", True) 22 | 23 | def _check_pipeline_two(context: SharedContext): 24 | self.assertIsNotNone(context) 25 | check = context.get_kv_data("check") 26 | self.assertTrue(check) 27 | 28 | source = FakeSource([generate_candle_with_price(TimeSpan.Day, datetime.now(), 1)]) 29 | processor = Processor() 30 | validator_one = TerminatorValidator(_check_pipeline_one) 31 | validator_two = TerminatorValidator(_check_pipeline_two) 32 | 33 | pipeline_one = Pipeline(source, processor, validator_one) 34 | pipeline_two = Pipeline(source, processor, validator_two) 35 | 36 | pipelines = [pipeline_one, pipeline_two] 37 | 38 | PipelineRunner(pipelines).run() 39 | -------------------------------------------------------------------------------- /tests/unit/test_returns_calculator_processor.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from unittest import TestCase 3 | 4 | from algotrader.entities.timespan import TimeSpan 5 | from algotrader.pipeline.pipeline import Pipeline 6 | from algotrader.pipeline.processors.candle_cache import CandleCache 7 | from algotrader.pipeline.processors.returns import ReturnsCalculatorProcessor, RETURNS_ATTACHMENT_KEY 8 | from algotrader.pipeline.reverse_source import ReverseSource 9 | from algotrader.pipeline.runner import PipelineRunner 10 | from algotrader.pipeline.shared_context import SharedContext 11 | from fakes.pipeline_validators import TerminatorValidator 12 | from fakes.source import FakeSource 13 | from unit import generate_candle_with_price, TEST_SYMBOL 14 | 15 | 16 | class TestReturnsCalculatorProcessor(TestCase): 17 | def setUp(self) -> None: 18 | super().setUp() 19 | self.source = FakeSource([ 20 | generate_candle_with_price(TimeSpan.Day, datetime.now() + timedelta(minutes=c), c) for c in range(1, 50) 21 | ]) 22 | 23 | def test(self): 24 | def _check_returns(context: SharedContext): 25 | self.assertIsNotNone(context) 26 | cache_reader = CandleCache.context_reader(context) 27 | candles = cache_reader.get_symbol_candles(TEST_SYMBOL) 28 | 29 | self.assertFalse(candles[0].get_attachment(RETURNS_ATTACHMENT_KEY).has("ctc-1")) 30 | self.assertFalse(candles[1].get_attachment(RETURNS_ATTACHMENT_KEY).has("ctc-1")) 31 | self.assertFalse(candles[2].get_attachment(RETURNS_ATTACHMENT_KEY).has("ctc-1")) 32 | 33 | ctc1 = candles[3].get_attachment(RETURNS_ATTACHMENT_KEY)["ctc-1"] 34 | ctc2 = candles[3].get_attachment(RETURNS_ATTACHMENT_KEY)["ctc-2"] 35 | ctc3 = 
candles[3].get_attachment(RETURNS_ATTACHMENT_KEY)["ctc-3"] 36 | self.assertTrue(ctc1 < ctc2 < ctc3) 37 | 38 | cache_processor = CandleCache() 39 | processor = ReturnsCalculatorProcessor("ctc", 3, cache_processor) 40 | 41 | terminator = TerminatorValidator(_check_returns) 42 | 43 | self.source = ReverseSource(self.source) 44 | PipelineRunner(Pipeline(self.source, processor, terminator)).run() 45 | -------------------------------------------------------------------------------- /tests/unit/test_reverse_source.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from unittest import TestCase 3 | 4 | from algotrader.entities.candle import Candle 5 | from algotrader.entities.timespan import TimeSpan 6 | from fakes.pipeline_validators import ValidationProcessor 7 | from fakes.source import FakeSource 8 | from algotrader.pipeline.pipeline import Pipeline 9 | from algotrader.pipeline.reverse_source import ReverseSource 10 | from algotrader.pipeline.runner import PipelineRunner 11 | from algotrader.pipeline.shared_context import SharedContext 12 | from unit import generate_candle_with_price 13 | 14 | 15 | class TestReverseSource(TestCase): 16 | def setUp(self) -> None: 17 | super().setUp() 18 | self.source = FakeSource([generate_candle_with_price(TimeSpan.Day, datetime.now(), c) for c in range(1, 50)]) 19 | 20 | def test_regular_order(self): 21 | def _check(context: SharedContext, candle: Candle): 22 | self.assertIsNotNone(context) 23 | 24 | last_price = context.get_kv_data("last_price") 25 | if last_price: 26 | self.assertTrue(candle.close > last_price) 27 | 28 | context.put_kv_data("last_price", candle.close) 29 | 30 | validator = ValidationProcessor(_check) 31 | PipelineRunner(Pipeline(self.source, validator)).run() 32 | 33 | def test_reverse_order(self): 34 | def _check(context: SharedContext, candle: Candle): 35 | self.assertIsNotNone(context) 36 | 37 | last_price = context.get_kv_data("last_price") 38 | if last_price: 39 | self.assertTrue(candle.close < last_price) 40 | 41 | context.put_kv_data("last_price", candle.close) 42 | 43 | validator = ValidationProcessor(_check) 44 | PipelineRunner(Pipeline(ReverseSource(self.source), validator)).run() 45 | -------------------------------------------------------------------------------- /tests/unit/test_serialization.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from unittest import TestCase 3 | 4 | import mongomock 5 | 6 | from algotrader.entities.timespan import TimeSpan 7 | from algotrader.pipeline.processors.candle_cache import CandleCache 8 | from algotrader.pipeline.processors.technicals_normalizer import TechnicalsNormalizerProcessor 9 | from algotrader.pipeline.sources.mongodb_source import MongoDBSource 10 | from algotrader.storage.mongodb_storage import MongoDBStorage 11 | 12 | 13 | class TestSerialization(TestCase): 14 | def test_serialize_processor(self): 15 | candle_cache_processor = CandleCache(CandleCache()) 16 | serialized = candle_cache_processor.serialize() 17 | self.assertEqual("algotrader.pipeline.processors.candle_cache:CandleCache", serialized["__class__"]) 18 | self.assertEqual( 19 | "algotrader.pipeline.processors.candle_cache:CandleCache", serialized["next_processor"]["__class__"] 20 | ) 21 | 22 | deserialized: CandleCache = CandleCache.deserialize(serialized) 23 | self.assertIsNotNone(deserialized) 24 | self.assertIsInstance(deserialized, CandleCache) 25 | 
self.assertIsInstance(deserialized.next_processor, CandleCache) 26 | 27 | def test_serialize_with_ctor(self): 28 | tech_buckets_matcher = TechnicalsNormalizerProcessor(666) 29 | serialized = tech_buckets_matcher.serialize() 30 | 31 | deserialized: TechnicalsNormalizerProcessor = TechnicalsNormalizerProcessor.deserialize(serialized) 32 | self.assertEqual(666, deserialized.normalization_window_size) 33 | 34 | def test_serialize_with_nested_ctor(self): 35 | tech_buckets_matcher = TechnicalsNormalizerProcessor(666) 36 | candle_cache_processor = CandleCache(tech_buckets_matcher) 37 | serialized = candle_cache_processor.serialize() 38 | 39 | deserialized: CandleCache = CandleCache.deserialize(serialized) 40 | self.assertIsInstance(deserialized, CandleCache) 41 | self.assertIsInstance(deserialized.next_processor, TechnicalsNormalizerProcessor) 42 | self.assertEqual(deserialized.next_processor.normalization_window_size, 666) 43 | 44 | @mongomock.patch(servers=(("host", 666),)) 45 | def test_serialize_complex_source(self): 46 | from_time = datetime.now() - timedelta(minutes=10) 47 | to_time = datetime.now() 48 | mongo_storage = MongoDBStorage("host", 666, "db") 49 | mongo_source = MongoDBSource(mongo_storage, ["X", "Y"], TimeSpan.Day, from_time, to_time) 50 | 51 | serialized = mongo_source.serialize() 52 | deserialized: MongoDBSource = MongoDBSource.deserialize(serialized) 53 | self.assertEqual("host", deserialized.mongo_storage.host) 54 | self.assertEqual(666, deserialized.mongo_storage.port) 55 | self.assertEqual("db", deserialized.mongo_storage.database) 56 | self.assertEqual(from_time, deserialized.from_time) 57 | self.assertEqual(to_time, deserialized.to_time) 58 | -------------------------------------------------------------------------------- /tests/unit/test_serializations.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from unittest import TestCase 3 | 4 | from algotrader.entities.attachments.nothing import NothingClass 5 | from algotrader.entities.bucket import Bucket 6 | from algotrader.entities.bucketscontainer import BucketsContainer 7 | from algotrader.entities.candle import Candle 8 | from algotrader.entities.timespan import TimeSpan 9 | from unit import generate_candle_with_price 10 | 11 | 12 | class TestSerializations(TestCase): 13 | def test_candle(self): 14 | candle = generate_candle_with_price(TimeSpan.Day, datetime.now(), 888) 15 | data = candle.model_dump_json() 16 | 17 | new_candle = Candle.model_validate_json(data) 18 | 19 | self.assertEqual(candle.symbol, new_candle.symbol) 20 | self.assertEqual(candle.timestamp, new_candle.timestamp) 21 | self.assertEqual(candle.time_span, new_candle.time_span) 22 | self.assertEqual(candle.close, new_candle.close) 23 | self.assertEqual(candle.high, new_candle.high) 24 | self.assertEqual(candle.low, new_candle.low) 25 | self.assertEqual(candle.volume, new_candle.volume) 26 | self.assertEqual(candle.open, new_candle.open) 27 | 28 | def test_candle_attachments(self): 29 | candle = generate_candle_with_price(TimeSpan.Day, datetime.now(), 888) 30 | candle.add_attachment("key1", NothingClass()) 31 | 32 | data = candle.model_dump_json() 33 | new_candle = Candle.model_validate_json(data) 34 | 35 | self.assertEqual(candle.symbol, new_candle.symbol) 36 | original_attachment = candle.get_attachment("key1") 37 | new_attachment = new_candle.get_attachment("key1") 38 | self.assertEqual(original_attachment.__class__, new_attachment.__class__) 39 | 40 | def test_bins(self): 41 | 
bins = BucketsContainer() 42 | bins.add("x", [Bucket(ident=1, start=1, end=2)]) 43 | bins.add("list", [[Bucket(ident=0, start=1, end=2)], [Bucket(ident=1, start=3, end=4)]]) 44 | 45 | serialized_data = bins.model_dump_json() 46 | new_bins = BucketsContainer.model_validate_json(serialized_data) 47 | 48 | x = new_bins.get("x") 49 | self.assertIsNotNone(x) 50 | self.assertEqual(1, x[0].start) 51 | self.assertEqual(2, x[0].end) 52 | 53 | lst = new_bins.get("list") 54 | self.assertIsNotNone(lst) 55 | self.assertTrue(isinstance(lst[0], list)) 56 | self.assertEqual(1, lst[0][0].start) 57 | self.assertEqual(2, lst[0][0].end) 58 | self.assertEqual(3, lst[1][0].start) 59 | self.assertEqual(4, lst[1][0].end) 60 | -------------------------------------------------------------------------------- /tests/unit/test_simple_sum_signals_executor.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from unittest import TestCase 3 | 4 | from algotrader.entities.strategy_signal import StrategySignal, SignalDirection 5 | from algotrader.entities.timespan import TimeSpan 6 | from algotrader.trade.simple_sum_signals_executor import DEFAULT_ORDER_VALUE, SimpleSumSignalsExecutor 7 | from unit import TEST_SYMBOL, generate_candle_with_price 8 | 9 | 10 | class TestSimpleSumSignalsExecutor(TestCase): 11 | def test_open_and_close_long(self): 12 | executor = SimpleSumSignalsExecutor() 13 | candle = generate_candle_with_price(TimeSpan.Day, datetime.now(), DEFAULT_ORDER_VALUE) 14 | signal = StrategySignal(TEST_SYMBOL, SignalDirection.Long) 15 | # test opening a long trade 16 | executor.execute(candle, [signal]) 17 | self.assertEqual(1, executor.position[TEST_SYMBOL]) 18 | self.assertEqual(-DEFAULT_ORDER_VALUE, executor.cash) 19 | # test closing a long trade 20 | candle = generate_candle_with_price(TimeSpan.Day, datetime.now(), DEFAULT_ORDER_VALUE + 88) 21 | executor.execute(candle, []) 22 | self.assertEqual(0, executor.position[TEST_SYMBOL]) 23 | self.assertEqual(88, executor.cash) 24 | 25 | def test_open_and_close_short(self): 26 | executor = SimpleSumSignalsExecutor() 27 | candle = generate_candle_with_price(TimeSpan.Day, datetime.now(), DEFAULT_ORDER_VALUE) 28 | signal = StrategySignal(TEST_SYMBOL, SignalDirection.Short) 29 | # test opening a short trade 30 | executor.execute(candle, [signal]) 31 | self.assertEqual(-1, executor.position[TEST_SYMBOL]) 32 | self.assertEqual(DEFAULT_ORDER_VALUE, executor.cash) 33 | # test closing a short trade 34 | candle = generate_candle_with_price(TimeSpan.Day, datetime.now(), DEFAULT_ORDER_VALUE - 88) 35 | executor.execute(candle, []) 36 | self.assertEqual(0, executor.position[TEST_SYMBOL]) 37 | self.assertEqual(88, executor.cash) 38 | -------------------------------------------------------------------------------- /tests/unit/test_strategy_processor.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import List 3 | from unittest import TestCase 4 | 5 | from algotrader.entities.candle import Candle 6 | from algotrader.entities.strategy import Strategy 7 | from algotrader.entities.strategy_signal import StrategySignal, SignalDirection 8 | from algotrader.entities.timespan import TimeSpan 9 | from fakes.strategy_executor import FakeSignalsExecutor 10 | from algotrader.pipeline.processors.strategy import StrategyProcessor 11 | from algotrader.pipeline.shared_context import SharedContext 12 | from unit import TEST_SYMBOL, generate_candle 13 
| 14 | 15 | class DummyStrategy(Strategy): 16 | def process(self, context: SharedContext, candle: Candle) -> List[StrategySignal]: 17 | return [StrategySignal(candle.symbol, SignalDirection.Long)] 18 | 19 | 20 | class NoSignalStrategy(Strategy): 21 | def process(self, context: SharedContext, candle: Candle) -> List[StrategySignal]: 22 | return [] 23 | 24 | 25 | class TestStrategyProcessor(TestCase): 26 | def test_signal_strategy(self): 27 | def _check(signals: List[StrategySignal]): 28 | self.assertEqual(1, len(signals)) 29 | self.assertEqual(SignalDirection.Long, signals[0].direction) 30 | self.assertEqual(TEST_SYMBOL, signals[0].symbol) 31 | 32 | candle = generate_candle(TimeSpan.Day, datetime.now()) 33 | processor = StrategyProcessor([DummyStrategy()], FakeSignalsExecutor(_check), None) 34 | processor.process(SharedContext(), candle) 35 | 36 | def test_multiple_strategies(self): 37 | def _check(signals: List[StrategySignal]): 38 | self.assertEqual(3, len(signals)) 39 | for i in range(3): 40 | self.assertEqual(SignalDirection.Long, signals[i].direction) 41 | self.assertEqual(TEST_SYMBOL, signals[i].symbol) 42 | 43 | candle = generate_candle(TimeSpan.Day, datetime.now()) 44 | processor = StrategyProcessor([DummyStrategy()] * 3, FakeSignalsExecutor(_check), None) 45 | processor.process(SharedContext(), candle) 46 | 47 | def test_no_signal(self): 48 | def _check(signals: List[StrategySignal]): 49 | self.assertEqual(0, len(signals)) 50 | 51 | candle = generate_candle(TimeSpan.Day, datetime.now()) 52 | processor = StrategyProcessor([NoSignalStrategy()], FakeSignalsExecutor(_check), None) 53 | processor.process(SharedContext(), candle) 54 | -------------------------------------------------------------------------------- /tests/unit/test_timespan_change_processor.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from unittest import TestCase 3 | 4 | from algotrader.entities.candle import Candle 5 | from algotrader.entities.event import Event 6 | from algotrader.entities.timespan import TimeSpan 7 | from algotrader.pipeline.pipeline import Pipeline 8 | from algotrader.pipeline.processors.candle_cache import CandleCache 9 | from algotrader.pipeline.processors.timespan_change import TimeSpanChangeProcessor 10 | from algotrader.pipeline.runner import PipelineRunner 11 | from algotrader.pipeline.shared_context import SharedContext 12 | from fakes.pipeline_validators import TerminatorValidator, ValidationProcessor 13 | from fakes.source import FakeSource 14 | from unit import generate_candle_with_price 15 | 16 | 17 | class TestTimeSpanChangeProcessor(TestCase): 18 | def setUp(self) -> None: 19 | super().setUp() 20 | self.source = FakeSource([ 21 | generate_candle_with_price(TimeSpan.Day, datetime.fromtimestamp(1669050000) - timedelta(hours=c), c) 22 | for c in range(1, 55) 23 | ]) 24 | 25 | def test(self): 26 | def _terminate(context: SharedContext): 27 | self.assertIsNotNone(context) 28 | event_count = context.get_kv_data("event_count", 0) 29 | candle_count = context.get_kv_data("candle_count", 0) 30 | self.assertEqual(event_count, 2) 31 | self.assertEqual(candle_count, 54) 32 | 33 | def _process(context: SharedContext, candle: Candle): 34 | self.assertIsNotNone(context) 35 | context.put_kv_data("candle_count", context.get_kv_data("candle_count", 0) + 1) 36 | 37 | def _event(context: SharedContext, event: Event): 38 | self.assertIsNotNone(context) 39 | 40 | if event != Event.TimeSpanChange: 41 | return 42 | 43 | 
context.put_kv_data("event_count", context.get_kv_data("event_count", 0) + 1) 44 | 45 | candle_count = context.get_kv_data("candle_count", 0) 46 | self.assertTrue(candle_count > 0) 47 | 48 | terminator = TerminatorValidator(_terminate) 49 | 50 | validator = ValidationProcessor(_process, _event) 51 | cache_processor = CandleCache(validator) 52 | processor = TimeSpanChangeProcessor(TimeSpan.Day, cache_processor) 53 | PipelineRunner(Pipeline(self.source, processor, terminator)).run() 54 | --------------------------------------------------------------------------------
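
Editorial note (not a file in the repository): the unit tests above share one recurring pattern — generate candles with the helpers in tests/unit/__init__.py, feed them through a FakeSource, chain processors via their next_processor argument, and drive the chain with PipelineRunner, asserting inside a ValidationProcessor or TerminatorValidator callback. The sketch below is assembled only from constructors and helpers that appear in these tests and assumes it runs from the test tree (so that the fakes and unit packages are importable, as in the tests themselves).

from datetime import datetime, timedelta

from algotrader.entities.candle import Candle
from algotrader.entities.timespan import TimeSpan
from algotrader.pipeline.pipeline import Pipeline
from algotrader.pipeline.processors.candle_cache import CandleCache
from algotrader.pipeline.runner import PipelineRunner
from algotrader.pipeline.shared_context import SharedContext
from fakes.pipeline_validators import ValidationProcessor
from fakes.source import FakeSource
from unit import generate_candle_with_price


def _check(context: SharedContext, candle: Candle) -> None:
    # Runs for every candle after CandleCache has stored it; the context reader
    # is the same helper used in test_candle_cache and test_asset_correlation.
    cached = CandleCache.context_reader(context).get_symbol_candles(candle.symbol)
    assert cached and cached[-1].symbol == candle.symbol


# Ten daily candles with increasing prices, as in the processor tests above.
source = FakeSource([
    generate_candle_with_price(TimeSpan.Day, datetime.now() + timedelta(minutes=c), c) for c in range(1, 11)
])
pipeline = Pipeline(source, CandleCache(ValidationProcessor(_check)))
PipelineRunner(pipeline).run()
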