├── .dockerignore ├── app ├── api │ ├── __init__.py │ └── api_v1 │ │ ├── __init__.py │ │ ├── endpoints │ │ ├── __init__.py │ │ ├── data.py │ │ └── strategy.py │ │ └── api.py ├── core │ ├── __init__.py │ └── commands │ │ ├── __init__.py │ │ ├── command.py │ │ └── strategy │ │ ├── __init__.py │ │ └── commands.py ├── tests │ ├── __init__.py │ ├── temp │ │ └── .keep │ ├── test_positions.py │ ├── test_adapter.py │ ├── test_command.py │ ├── test_position_testing.py │ ├── test_strategy.py │ ├── test_order_manager.py │ ├── test_ohlc.py │ ├── test_parameter.py │ ├── test_signals.py │ ├── test_order.py │ └── test_backtest.py ├── utils │ ├── __init__.py │ ├── loaders │ │ ├── __init__.py │ │ ├── load_all.py │ │ └── strategy_loader.py │ ├── commons.py │ ├── wiki_link.py │ ├── formatting.py │ └── create_file.py ├── benchmarks │ ├── __init__.py │ ├── profile_fill_orders.py │ └── profile_strategy_with_builtins.py ├── components │ ├── backtest │ │ ├── __init__.py │ │ ├── utils.py │ │ └── backtest.py │ ├── manager │ │ ├── __init__.py │ │ └── manager.py │ ├── ohlc │ │ ├── data_adapters │ │ │ ├── utils.py │ │ │ ├── __init__.py │ │ │ ├── api_adapter.py │ │ │ └── adapter.py │ │ ├── __init__.py │ │ ├── symbol.py │ │ └── ohlc.py │ ├── strategy │ │ ├── builtins │ │ │ ├── __init__.py │ │ │ └── ta │ │ │ │ ├── sma.py │ │ │ │ ├── logic.py │ │ │ │ ├── __init__.py │ │ │ │ ├── kalman_filter.py │ │ │ │ ├── atr.py │ │ │ │ └── correlation.py │ │ ├── templates │ │ │ ├── __init__.py │ │ │ └── strategy_template.py │ │ ├── __init__.py │ │ ├── decorators.py │ │ ├── series.py │ │ └── strategy.py │ ├── positions │ │ ├── __init__.py │ │ ├── enums.py │ │ ├── exceptions.py │ │ ├── utils.py │ │ ├── position_manager.py │ │ └── positions.py │ ├── orders │ │ ├── __init__.py │ │ ├── enums.py │ │ ├── signals.py │ │ ├── order_manager.py │ │ └── order.py │ ├── __init__.py │ └── parameter.py ├── settings.py ├── manage.py ├── Dockerfile ├── main.py ├── storage │ └── strategies │ │ └── examples │ │ ├── ohlc_demo.py │ │ 
├── sma_cross_over.py │ │ ├── using_builtins.py │ │ └── sma_cross_over_advanced.py └── requirements.txt ├── .idea ├── vcs.xml ├── misc.xml ├── inspectionProfiles │ ├── profiles_settings.xml │ └── Project_Default.xml ├── modules.xml └── stratis-v2.iml ├── docker-compose-standalone.yml ├── ui ├── README.md └── Dockerfile ├── nginx.conf ├── docker-compose.yml ├── .github └── workflows │ └── pytest.yml ├── LICENSE.md ├── README.md └── .gitignore /.dockerignore: -------------------------------------------------------------------------------- 1 | .venv -------------------------------------------------------------------------------- /app/api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/tests/temp/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/api/api_v1/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/tests/test_positions.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/core/commands/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/utils/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/api/api_v1/endpoints/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/components/backtest/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/components/manager/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/components/ohlc/data_adapters/utils.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/components/strategy/builtins/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/components/strategy/templates/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/components/positions/__init__.py: -------------------------------------------------------------------------------- 1 | from components.positions.positions import Position 
-------------------------------------------------------------------------------- /app/components/orders/__init__.py: -------------------------------------------------------------------------------- 1 | from components.orders.order import Order, StopOrder, LimitOrder -------------------------------------------------------------------------------- /app/utils/commons.py: -------------------------------------------------------------------------------- 1 | STRATEGY_TEMPLATE_PATH = 'components/strategy/templates/strategy_template.py' -------------------------------------------------------------------------------- /app/components/positions/enums.py: -------------------------------------------------------------------------------- 1 | class PositionEffect: 2 | REDUCE = 'reduce' 3 | ADD = 'add' 4 | -------------------------------------------------------------------------------- /app/components/ohlc/data_adapters/__init__.py: -------------------------------------------------------------------------------- 1 | from components.ohlc.data_adapters.adapter import DataAdapter, CSVAdapter -------------------------------------------------------------------------------- /app/components/ohlc/__init__.py: -------------------------------------------------------------------------------- 1 | from components.ohlc.ohlc import OHLC 2 | from components.ohlc.data_adapters import DataAdapter, CSVAdapter -------------------------------------------------------------------------------- /app/utils/wiki_link.py: -------------------------------------------------------------------------------- 1 | def wiki_link(link: str) -> str: 2 | """Returns a link to the wiki.""" 3 | return ( 4 | f"\n\n\tSee Wiki: {link}\n\n" 5 | ) -------------------------------------------------------------------------------- /app/components/__init__.py: -------------------------------------------------------------------------------- 1 | from components.parameter import Parameter 2 | from components.strategy import Strategy 3 | 
from components.strategy.decorators import on_step 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /docker-compose-standalone.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | 3 | services: 4 | app: 5 | build: 6 | context: . 7 | dockerfile: app/Dockerfile 8 | ports: 9 | - "8000:8000" -------------------------------------------------------------------------------- /app/components/strategy/builtins/ta/sma.py: -------------------------------------------------------------------------------- 1 | from components.strategy import Series 2 | 3 | 4 | def sma(data, period) -> Series: 5 | """Simple Moving Average""" 6 | return Series(data.rolling(period).mean()) -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /app/settings.py: -------------------------------------------------------------------------------- 1 | from components.ohlc import CSVAdapter 2 | from components.ohlc.data_adapters.api_adapter import APIDataAdapter 3 | 4 | # add any custom data adapters here 5 | DATA_ADAPTERS = [ 6 | CSVAdapter, 7 | APIDataAdapter 8 | ] -------------------------------------------------------------------------------- /app/components/strategy/__init__.py: -------------------------------------------------------------------------------- 1 | from 
components.strategy.strategy import BaseStrategy as Strategy 2 | from components.strategy.decorators import on_step 3 | from components.strategy.series import Series 4 | from components.strategy.builtins import ta -------------------------------------------------------------------------------- /app/core/commands/command.py: -------------------------------------------------------------------------------- 1 | from utils.formatting import snake_case, pascal_case 2 | 3 | 4 | class BaseCommand: 5 | 6 | help = 'Base command' 7 | 8 | def handle(self, *args, **kwargs): 9 | raise NotImplementedError 10 | -------------------------------------------------------------------------------- /app/components/positions/exceptions.py: -------------------------------------------------------------------------------- 1 | class PositionValidationException(Exception): 2 | pass 3 | 4 | 5 | class PositionUnbalancedException(Exception): 6 | pass 7 | 8 | 9 | class PositionClosedException(Exception): 10 | pass 11 | -------------------------------------------------------------------------------- /ui/README.md: -------------------------------------------------------------------------------- 1 | # Stratis UI Repo 2 | 3 | [https://github.com/robswc/stratis-ui](https://github.com/robswc/stratis-ui) 4 | 5 | This directory simply contains the Dockerfile for the Stratis UI. 6 | The Dockerfile is used to build the image that is used to run the Stratis UI. -------------------------------------------------------------------------------- /app/manage.py: -------------------------------------------------------------------------------- 1 | """ 2 | A command-line utility for managing the app. 
3 | """ 4 | from core.commands import strategy 5 | 6 | import typer 7 | 8 | app = typer.Typer() 9 | 10 | app.add_typer(strategy.app, name='strategy') 11 | 12 | if __name__ == '__main__': 13 | app() -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /app/components/strategy/builtins/ta/logic.py: -------------------------------------------------------------------------------- 1 | 2 | class Logic: 3 | @staticmethod 4 | def crossover(a: 'Series', b: 'Series') -> bool: 5 | return a > b and a.shift(1) < b.shift(1) 6 | 7 | @staticmethod 8 | def crossunder(a: 'Series', b: 'Series') -> bool: 9 | return a < b and a.shift(1) > b.shift(1) -------------------------------------------------------------------------------- /app/api/api_v1/api.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter 2 | 3 | from api.api_v1.endpoints import data 4 | from api.api_v1.endpoints import strategy 5 | 6 | api_router = APIRouter() 7 | api_router.include_router(strategy.router, prefix="/strategy", tags=["strategy"]) 8 | api_router.include_router(data.router, prefix="/data", tags=["data"]) -------------------------------------------------------------------------------- /app/components/strategy/builtins/ta/__init__.py: -------------------------------------------------------------------------------- 1 | from components.strategy.builtins.ta.sma import * 2 | from components.strategy.builtins.ta.correlation import * 3 | from components.strategy.builtins.ta.kalman_filter import * 4 | from components.strategy.builtins.ta.atr import * 5 | from components.strategy.builtins.ta.logic import Logic as logic -------------------------------------------------------------------------------- /app/Dockerfile: 
-------------------------------------------------------------------------------- 1 | FROM python:3.10 2 | 3 | COPY requirements.txt /tmp/requirements.txt 4 | RUN pip install -r /tmp/requirements.txt 5 | 6 | COPY . /app 7 | WORKDIR /app 8 | 9 | EXPOSE 8000 10 | 11 | # run the FastAPI app with gunicorn, using uvicorn workers 12 | CMD ["gunicorn", "-k", "uvicorn.workers.UvicornWorker", "-b", "0.0.0.0:8000", "main:app"] -------------------------------------------------------------------------------- /app/benchmarks/profile_fill_orders.py: -------------------------------------------------------------------------------- 1 | from components.ohlc import CSVAdapter 2 | from storage.strategies.examples.sma_cross_over_advanced import SMACrossOverAdvanced 3 | from storage.strategies.examples.using_builtins import UsingBuiltins 4 | 5 | adapter = CSVAdapter() 6 | ohlc = adapter.get_data(path='../tests/data/AAPL.csv') 7 | 8 | strategy = SMACrossOverAdvanced() 9 | strategy.run(ohlc) 10 | 11 | # 600ms 12 | -------------------------------------------------------------------------------- /nginx.conf: -------------------------------------------------------------------------------- 1 | events {} 2 | 3 | http { 4 | upstream app_backend { 5 | server app:8000; 6 | } 7 | 8 | upstream stratis_next_js { 9 | server stratis-next-js:3000; 10 | } 11 | 12 | server { 13 | location /api { 14 | proxy_pass http://app_backend; 15 | } 16 | 17 | location / { 18 | proxy_pass http://stratis_next_js; 19 | } 20 | } 21 | } -------------------------------------------------------------------------------- /app/benchmarks/profile_strategy_with_builtins.py: -------------------------------------------------------------------------------- 1 | from components.ohlc import CSVAdapter 2 | from storage.strategies.examples.using_builtins import UsingBuiltins 3 | 4 | adapter = CSVAdapter() 5 | ohlc = adapter.get_data(path='../tests/data/AAPL.csv') 6 | 7 | strategy = UsingBuiltins() 8 | strategy.run(ohlc) 9 | 10 | # 
3/29/2023 11 | # first run: 1700ms 12 | # second run: 1400ms 13 | # third run: 1000ms 14 | # fourth run: 800ms 15 | # fifth run: 700ms 16 | # sixth run: sub 700ms 17 | -------------------------------------------------------------------------------- /app/components/strategy/templates/strategy_template.py: -------------------------------------------------------------------------------- 1 | from components import Strategy, on_step, Parameter 2 | 3 | 4 | class $StrategyName(Strategy): 5 | """Edit your strategy description here.""" 6 | 7 | # my_parameter = Parameter(10) # Example parameter, uncomment to use 8 | 9 | def __init__(self, *args, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | # add any pre-compiled data here 12 | # https://github.com/robswc/stratis/wiki/Strategies#init-and-pre-compiled-data-series 13 | 14 | @on_step 15 | def do_stuff(self): 16 | print('Running on each step...') -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | 3 | services: 4 | app: 5 | build: 6 | context: ./app 7 | dockerfile: Dockerfile 8 | expose: 9 | - "8000" 10 | ports: 11 | - "8000:8000" 12 | 13 | stratis-next-js: 14 | depends_on: 15 | - app 16 | build: 17 | context: ./ui 18 | dockerfile: Dockerfile 19 | expose: 20 | - "3000" 21 | 22 | nginx: 23 | image: nginx:latest 24 | volumes: 25 | - ./nginx.conf:/etc/nginx/nginx.conf 26 | depends_on: 27 | - app 28 | - stratis-next-js 29 | ports: 30 | - "${NGINX_PORT:-3000}:80" -------------------------------------------------------------------------------- /ui/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM node:18-alpine AS base 2 | 3 | # Install git 4 | RUN apk add --no-cache git 5 | 6 | # Clone from GitHub 7 | ADD https://api.github.com/repos/robswc/stratis-ui latest_commit 8 | RUN git clone https://github.com/robswc/stratis-ui.git 
/app 9 | 10 | # Set working directory 11 | WORKDIR /app 12 | 13 | # Install dependencies 14 | RUN npm install 15 | 16 | # Add args 17 | ARG MAX_OLD_SPACE_SIZE=1024 18 | 19 | # Build NextJS app 20 | RUN npm run build --max-old-space-size=${MAX_OLD_SPACE_SIZE} 21 | 22 | # Expose port 3000 23 | EXPOSE 3000 24 | 25 | # Run NextJS app 26 | CMD ["npm", "run", "start"] 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /app/main.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | from starlette.middleware.cors import CORSMiddleware 3 | 4 | from api.api_v1.api import api_router 5 | from utils.loaders import load_all # noqa: F401 6 | 7 | app = FastAPI( 8 | title="Stratis API", 9 | ) 10 | 11 | allowed_origins = [ 12 | "http://localhost", 13 | "http://localhost:8000", 14 | "http://localhost:3000", 15 | ] 16 | 17 | # handle CORS 18 | app.add_middleware( 19 | CORSMiddleware, 20 | allow_origins=allowed_origins, 21 | allow_credentials=True, 22 | allow_methods=["*"], 23 | allow_headers=["*"], 24 | ) 25 | 26 | app.include_router(api_router, prefix="/api/v1") -------------------------------------------------------------------------------- /app/core/commands/strategy/__init__.py: -------------------------------------------------------------------------------- 1 | import typer 2 | 3 | from core.commands.strategy.commands import CreateNewStrategy, ListStrategies 4 | 5 | app = typer.Typer( 6 | help="Manage strategies", 7 | ) 8 | 9 | 10 | @app.command() 11 | def create(name: str = typer.Argument(..., help=CreateNewStrategy.help)): 12 | """ 13 | Create a new strategy 14 | """ 15 | typer.echo(f"Creating strategy: {name}") 16 | CreateNewStrategy(strategy_name=name).handle() 17 | 18 | @app.command(name="list") 19 | def list_strategies(): 20 | """ 21 | List all strategies 22 | """ 23 | typer.echo("Listing strategies") 24 | ListStrategies().handle() 
-------------------------------------------------------------------------------- /app/storage/strategies/examples/ohlc_demo.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from components import Strategy, on_step 3 | 4 | 5 | class OHLCDemo(Strategy): 6 | 7 | @on_step 8 | def print_ohlc(self): 9 | 10 | # strategy shorthands for the OHLC data 11 | timestamp = self.data.timestamp 12 | close = self.data.close 13 | 14 | # if the timestamp is a multiple of 3600000 (1 hour) 15 | if timestamp % 3600000 == 0: 16 | # strategy a datetime object from the timestamp 17 | dt = datetime.datetime.fromtimestamp(timestamp / 1000) 18 | if dt.hour == 10: 19 | print(f'{dt}: {close}') 20 | 21 | 22 | -------------------------------------------------------------------------------- /app/components/ohlc/symbol.py: -------------------------------------------------------------------------------- 1 | class Symbol: 2 | """ 3 | Represents a tradeable instrument. Feel free to extend this class to fit your needs. 4 | """ 5 | 6 | def __init__(self, symbol: str): 7 | if type(symbol) != str: 8 | raise TypeError('Symbol must be a string.') 9 | self.symbol = symbol 10 | 11 | def __str__(self): 12 | return self.symbol 13 | 14 | 15 | class Equity(Symbol): 16 | """ 17 | Represents an equity instrument. 
18 | """ 19 | 20 | def __init__(self, symbol: str): 21 | super().__init__(symbol) 22 | self.cusip = None 23 | self.description = '' 24 | self.exchange = None 25 | self.type = None 26 | -------------------------------------------------------------------------------- /app/tests/test_adapter.py: -------------------------------------------------------------------------------- 1 | class TestDataAdapters: 2 | def test_init_data_adapter(self): 3 | from components.ohlc import DataAdapter 4 | adapter = DataAdapter() 5 | assert adapter.name == 'DataAdapter' 6 | 7 | # test csv adapter 8 | from components.ohlc import CSVAdapter 9 | csv_adapter = CSVAdapter() 10 | assert csv_adapter.name == 'CSVAdapter' 11 | 12 | def test_csv_adapter(self): 13 | from components.ohlc import CSVAdapter 14 | csv_adapter = CSVAdapter() 15 | ohlc = csv_adapter.get_data(None, None, 'tests/data/AAPL.csv', 'AAPL') 16 | assert str(ohlc.symbol) == 'AAPL' 17 | assert ohlc.dataframe.shape == (5001, 5) 18 | -------------------------------------------------------------------------------- /app/components/manager/manager.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | 3 | 4 | class ComponentManager: 5 | 6 | _components = [] 7 | 8 | @classmethod 9 | def register(cls, component): 10 | if component in cls._components: 11 | return 12 | cls._components.append(component) 13 | logger.debug(f'Registered component {component} ({component.__module__})') 14 | 15 | @classmethod 16 | def all(cls): 17 | return [o() for o in cls._components] 18 | 19 | @classmethod 20 | def get(cls, name): 21 | for component in cls._components: 22 | if component.__name__ == name: 23 | return component() 24 | raise ValueError(f'Component {name} not found.') -------------------------------------------------------------------------------- /app/components/strategy/builtins/ta/kalman_filter.py: -------------------------------------------------------------------------------- 1 | 
import math 2 | 3 | import pandas as pd 4 | from loguru import logger 5 | 6 | from components.strategy import Series 7 | import numpy as np 8 | 9 | 10 | def kalman_filter(src: pd.Series, gain): 11 | src = np.array(src) 12 | n = len(src) 13 | kf = np.zeros(n) 14 | velo = np.zeros(n) 15 | smooth = np.zeros(n) 16 | gain_sqrt = np.sqrt(gain / 5000) 17 | 18 | for i in range(n): 19 | if i > 0: 20 | dk = src[i] - kf[i - 1] 21 | else: 22 | dk = src[i] 23 | smooth[i] = kf[i - 1] + dk * gain_sqrt if i > 0 else src[i] 24 | velo[i] = velo[i - 1] + (gain / 10000) * dk 25 | kf[i] = smooth[i] + velo[i] 26 | return Series(list(kf)) 27 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: ["3.8", "3.9", "3.10"] 12 | 13 | steps: 14 | - uses: actions/checkout@v3 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v4 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Install dependencies 20 | working-directory: ./app 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install pytest 24 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 25 | - name: Test with pytest 26 | working-directory: ./app 27 | run: | 28 | pytest --rootdir=./tests 29 | -------------------------------------------------------------------------------- /app/utils/formatting.py: -------------------------------------------------------------------------------- 1 | 2 | def camel_case(value): 3 | """Converts a string to camel case.""" 4 | return ''.join(x.capitalize() or '_' for x in value.split('_')) 5 | 6 | def snake_case(value): 7 | """Converts a string to snake case.""" 8 | return ''.join(x.lower() if x.islower() else '_' + x.lower() for x in 
value) 9 | 10 | def pascal_case(value): 11 | """Converts a string to pascal case.""" 12 | return ''.join(x.capitalize() for x in value.split('_')) 13 | 14 | def pascal_to_snake_case(value: str): 15 | """Converts a string from pascal case to snake case.""" 16 | case = '' 17 | for i, c in enumerate(value): 18 | if c.isupper() and i != 0: 19 | case += '_' 20 | case += c.lower() 21 | return case 22 | 23 | def kebab_case(value): 24 | """Converts a string to kebab case.""" 25 | return '-'.join(x.lower() for x in value.split('_')) -------------------------------------------------------------------------------- /app/tests/test_command.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from core.commands.strategy import CreateNewStrategy 4 | 5 | 6 | class TestCommands: 7 | def test_create_strategy(self): 8 | 9 | # Create a new strategy 10 | test_path = 'tests/temp' 11 | CreateNewStrategy(strategy_name='MyTestStrategy').handle(override_path=test_path, prompt=False) 12 | assert os.path.exists(f'{test_path}/my_test_strategy.py') 13 | os.remove(f'{test_path}/my_test_strategy.py') 14 | 15 | # test that exception is raised when strategy name is invalid 16 | try: 17 | CreateNewStrategy(strategy_name='MyTestStrategy').handle(override_path=test_path) 18 | except ValueError: 19 | assert True 20 | 21 | try: 22 | CreateNewStrategy(strategy_name='My Test Strategy').handle(override_path=test_path) 23 | except ValueError: 24 | assert True -------------------------------------------------------------------------------- /app/requirements.txt: -------------------------------------------------------------------------------- 1 | anyio==3.6.1 2 | attrs==22.1.0 3 | certifi==2022.9.24 4 | charset-normalizer==2.1.1 5 | click==8.1.3 6 | colorama==0.4.6 7 | commonmark==0.9.1 8 | fastapi==0.85.0 9 | gunicorn==20.1.0 10 | h11==0.14.0 11 | httptools==0.5.0 12 | idna==3.4 13 | importlib-metadata==6.0.0 14 | iniconfig==1.1.1 15 | Jinja2==3.1.2 16 | 
llvmlite==0.39.1 17 | loguru==0.6.0 18 | MarkupSafe==2.1.1 19 | numba==0.56.4 20 | numpy==1.23.4 21 | packaging==21.3 22 | pandas==1.5.1 23 | pluggy==1.0.0 24 | py==1.11.0 25 | pydantic==1.10.2 26 | Pygments==2.14.0 27 | pyparsing==3.0.9 28 | pytest==7.1.3 29 | python-dateutil==2.8.2 30 | python-dotenv==0.21.0 31 | pytz==2022.5 32 | PyYAML==6.0 33 | requests==2.28.1 34 | rich==12.6.0 35 | shellingham==1.5.0.post1 36 | six==1.16.0 37 | sniffio==1.3.0 38 | starlette==0.20.4 39 | tomli==2.0.1 40 | typer==0.7.0 41 | typing_extensions==4.3.0 42 | urllib3==2.6.0 43 | uvicorn==0.18.3 44 | uvloop==0.17.0 45 | watchfiles==0.17.0 46 | websockets==10.3 47 | zipp==3.11.0 48 | -------------------------------------------------------------------------------- /.idea/stratis-v2.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 14 | 15 | 20 | 21 | 22 | 24 | -------------------------------------------------------------------------------- /app/components/strategy/builtins/ta/atr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from components.strategy import Series 4 | 5 | 6 | def atr(high: Series, low: Series, close: Series, period: int = 12) -> Series: 7 | """ Calculate the average true range """ 8 | 9 | high = np.array(high) 10 | low = np.array(low) 11 | close = np.array(close) 12 | 13 | if len(high) != len(low) != len(close): 14 | raise ValueError("Input lists must have the same length") 15 | 16 | if len(high) < period: 17 | raise ValueError("Input lists must have at least 'period' number of elements") 18 | 19 | true_range = [] 20 | 21 | for i in range(1, len(high)): 22 | tr = max(high[i] - low[i], abs(high[i] - close[i - 1]), abs(low[i] - close[i - 1])) 23 | true_range.append(tr) 24 | 25 | result = [] 26 | 27 | for i in range(period, len(true_range) + 1): 28 | average = np.mean(true_range[i - period: i]) 29 | result.append(average) 
30 | 31 | return Series(result) 32 | -------------------------------------------------------------------------------- /app/components/backtest/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from components.positions import Position 4 | 5 | 6 | def remove_overlapping_positions(positions: List[Position], max_overlap: int = 1): 7 | """Remove overlapping positions from a list of positions.""" 8 | 9 | new_positions = [] 10 | 11 | # Sort the positions by their opened timestamp 12 | sorted_positions = sorted(positions, key=lambda x: x.opened_timestamp) 13 | 14 | i = 0 15 | last_timestamp = None 16 | while i < len(sorted_positions): 17 | p = sorted_positions[i] 18 | if last_timestamp is None: 19 | last_timestamp = p.closed_timestamp 20 | new_positions.append(p) 21 | continue 22 | 23 | no_overlap = last_timestamp < p.opened_timestamp 24 | if no_overlap: 25 | new_positions.append(p) 26 | last_timestamp = p.closed_timestamp 27 | 28 | i += 1 29 | 30 | # Return the modified copy of the list 31 | return new_positions 32 | -------------------------------------------------------------------------------- /app/utils/create_file.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def create_file(path: str, contents: str, prompt: bool = False) -> tuple: 5 | """Creates a file at the specified path with the specified contents.""" 6 | 7 | # first check that the file doesn't already exist 8 | if prompt: 9 | if os.path.exists(path): 10 | # if the file exists, prompt the user to overwrite it 11 | overwrite = input(f'File already exists at {path}. Overwrite? 
def extract_decorators(cls):
    """Group the callable attributes of ``cls`` by lifecycle keyword.

    Every callable, non-dunder attribute that exposes a ``__func__`` member is
    classified by searching ``str(__func__)`` for the words ``before``,
    ``step`` and ``after``. A name is added to every list whose keyword
    matches.

    NOTE(review): this is a plain substring test on the function's string
    form, so it presumably also matches undecorated methods whose own name
    contains one of the keywords - confirm whether that is intended.

    :param cls: object whose attributes are inspected
    :return: tuple ``(befores, steps, afters)`` of attribute names
    """
    befores, steps, afters = [], [], []
    keyword_buckets = (('before', befores), ('step', steps), ('after', afters))

    # Resolve the candidate names first, then classify them.
    candidates = [
        name for name in dir(cls)
        if callable(getattr(cls, name)) and not name.startswith("__")
    ]

    for name in candidates:
        members = dict(inspect.getmembers(getattr(cls, name)))
        if '__func__' not in members:
            continue
        for keyword, bucket in keyword_buckets:
            if keyword in str(members['__func__']):
                bucket.append(name)

    return befores, steps, afters
class OrderType(str, Enum):
    """Order types; being ``str`` based, members compare equal to their values."""

    MARKET = "market"
    LIMIT = "limit"
    STOP = "stop"
    STOP_LIMIT = "stop_limit"
    TRAILING_STOP = "trailing_stop"

    @staticmethod
    def abbreviation(order_type):
        """Return the short code for ``order_type``.

        :param order_type: an ``OrderType`` member or anything comparing equal to one
        :return: the abbreviation string, or ``None`` when nothing matches
        """
        short_codes = (
            (OrderType.MARKET, "mkt"),
            (OrderType.LIMIT, "lmt"),
            (OrderType.STOP, "stp"),
            (OrderType.STOP_LIMIT, "stpl"),
            (OrderType.TRAILING_STOP, "trsl"),
        )
        # Checked in declaration order using equality, so plain strings work too.
        for member, code in short_codes:
            if order_type == member:
                return code
        return None
def get_effect(position: 'Position', order: Order):
    """Classify how ``order`` changes the magnitude of ``position``.

    :param position: object exposing a numeric ``size``
    :param order: order whose ``qty`` is added to the position size
    :return: ``PositionEffect.ADD`` when the absolute size grows,
        otherwise ``PositionEffect.REDUCE``
    """
    size_before = abs(position.size)
    size_after = abs(position.size + order.qty)
    return PositionEffect.ADD if size_before < size_after else PositionEffect.REDUCE
def import_components(path, component_type):
    """Import every module below ``path`` and collect ``component_type`` subclasses.

    :param path: directory to scan, relative to the application root (the
        directory reached by three ``.parent`` steps from this file)
    :param component_type: base class whose subclasses are collected
    :return: list of the classes found; the base class itself (matched by
        name) is excluded, and a class visible from several scanned modules
        is listed once per module
    """
    logger.debug(f'Importing {component_type.__name__}(s) from {path}...')
    components = []
    app_path = Path(__file__).parent.parent.parent

    for module_path in app_path.joinpath(path).rglob('*.py'):
        # Build the dotted module name from the path relative to the app
        # root, so it does not depend on the root directory being called
        # "app" or on "app." / ".py" appearing elsewhere in the path.
        relative = module_path.relative_to(app_path).with_suffix('')
        module = import_module('.'.join(relative.parts))

        for _, obj in inspect.getmembers(module, inspect.isclass):
            if issubclass(obj, component_type) and obj.__name__ != component_type.__name__:
                components.append(obj)
                logger.info(f'\t->\t{obj.__name__} ({obj.__module__})')
    return components
class SMACrossOver(Strategy):
    """Example strategy: buy when the fast SMA crosses over the slow SMA, close on a cross under."""

    sma_fast_length = Parameter(10)
    sma_slow_length = Parameter(60)

    def __init__(self, *args, **kwargs):
        """Initialise the base strategy and build both averages from the close series."""
        super().__init__(*args, **kwargs)
        closes = self.data.all('close')
        self.sma_fast = ta.sma(closes, int(self.sma_fast_length))
        self.sma_slow = ta.sma(closes, int(self.sma_slow_length))

    @on_step
    def check_for_crossover(self):
        """Open a one-unit market buy on a cross over; close the latest position on a cross under."""
        # Both signals are evaluated up front, the cross over takes priority.
        crossed_over = ta.logic.crossover(self.sma_fast, self.sma_slow)
        crossed_under = ta.logic.crossunder(self.sma_fast, self.sma_slow)
        if crossed_over:
            self.positions.open(order_type='market', side='buy', quantity=1)
            return
        if crossed_under:
            self.positions.close()

    @after
    def create_plots(self):
        """Export both moving averages as named plots."""
        plots = [
            Plot(self.sma_fast, name='sma_fast'),
            Plot(self.sma_slow, name='sma_slow'),
        ]
        self.export_plots(plots)
class TestStrategy:
    """Smoke tests for strategy construction, execution, loading and registration."""

    def test_strategy_name(self):
        base_strategy = Strategy()
        assert base_strategy.name == 'BaseStrategy'

    def test_initializing_examples(self):
        from storage.strategies.examples.sma_cross_over import SMACrossOver
        example = SMACrossOver()
        assert example.name == 'SMACrossOver'
        # Parameters are compared through their integer conversion.
        expected_lengths = (('sma_fast_length', 10), ('sma_slow_length', 60))
        for attribute, expected in expected_lengths:
            assert int(getattr(example, attribute)) == expected

    def test_run_strategy(self):
        from storage.strategies.examples.sma_cross_over import SMACrossOver
        candles = OHLC.from_csv(CSV_PATH, 'AAPL')
        SMACrossOver(data=candles).run(data=candles)

    def test_ohlc_demo(self):
        from storage.strategies.examples.ohlc_demo import OHLCDemo
        demo = OHLCDemo()
        candles = OHLC.from_csv(CSV_PATH, 'AAPL')
        demo.run(data=candles)

    def test_load_strategies(self):
        from utils.loaders.strategy_loader import import_all_strategies
        loaded = import_all_strategies()
        assert len(loaded) > 0

    def test_strategy_manager(self):
        assert len(Strategy.objects.all()) > 0
        registered = Strategy.objects.get('SMACrossOver')
        assert registered.name == 'SMACrossOver'
class TestOrderManager:
    """Checks that OrderManager records orders created through it or added directly."""

    def test_market_order(self):
        manager = OrderManager(STRATEGY)
        for side in ('buy', 'sell'):
            manager.market_order(side=side, quantity=1)
        assert len(manager) == 2

    def test_add(self):
        manager = OrderManager(STRATEGY)
        # Fields shared by both hand-built orders.
        common = dict(
            qty=1,
            symbol=SYMBOL,
            filled_avg_price=CLOSE,
            timestamp=TIMESTAMP,
            type=OrderType.MARKET,
        )
        for side in ('buy', 'sell'):
            manager.add(Order(side=side, **common))
        assert len(manager) == 2

    def test_summary(self):
        manager = OrderManager(STRATEGY)
        for side in ('buy', 'sell'):
            manager.market_order(side=side, quantity=1)
        expected = {'total': 2}
        assert manager.summary() == expected
def import_all_strategies() -> List[Type[BaseStrategy]]:
    """Import every module under ``storage/strategies`` and collect strategy classes.

    The folder is resolved against the application root (the directory
    reached by three ``.parent`` steps from this file).

    :return: every ``BaseStrategy`` subclass found in the imported modules,
        excluding any class named ``BaseStrategy``; a class visible from
        several modules is listed once per module
    """
    strategies = []
    app_path = Path(__file__).parent.parent.parent

    for module_path in app_path.joinpath('storage/strategies').rglob('*.py'):
        # Build the dotted module name from the path relative to the app
        # root, so it does not depend on the root directory being called
        # "app" or on "app." / ".py" appearing elsewhere in the path.
        relative = module_path.relative_to(app_path).with_suffix('')
        module = import_module('.'.join(relative.parts))

        for _, obj in inspect.getmembers(module, inspect.isclass):
            if issubclass(obj, BaseStrategy) and obj.__name__ != 'BaseStrategy':
                strategies.append(obj)
    return strategies
class TestParameters:
    """Tests for the Parameter factory and the typed parameter classes."""

    def test_create_parameters(self):
        """Parameter() wraps its value in the class matching the value's type."""
        make_int_param = Parameter(1, 0, 10)
        make_int_value_param = Parameter(1)
        make_float_param = Parameter(1.0, 0.0, 10.0)
        make_bool_param = Parameter(True)
        assert isinstance(make_int_param.value, IntegerParameter)
        assert isinstance(make_int_value_param.value, IntegerParameter)
        assert isinstance(make_float_param.value, FloatParameter)
        assert isinstance(make_bool_param.value, BooleanParameter)

    def test_initialize_parameters(self):
        """Typed parameters store value and bounds, and reject invalid values."""
        int_param = IntegerParameter(1, 0, 10)
        float_param = FloatParameter(1.0, 0.0, 10.0)
        bool_param = BooleanParameter(True)
        assert int_param.value == 1
        assert int_param.min_value == 0
        assert int_param.max_value == 10
        assert float_param.value == 1.0
        assert float_param.min_value == 0.0
        assert float_param.max_value == 10.0
        assert bool_param.value == True

        # Each constructor gets its own context manager: the first exception
        # leaves a ``pytest.raises`` block, so statements after it in the same
        # block are never executed.
        # NOTE(review): the float and bool cases were not exercised before
        # this change - confirm both are expected to raise ValueError.
        with pytest.raises(ValueError):
            IntegerParameter(11, 0, 10)
        with pytest.raises(ValueError):
            FloatParameter(11.0, 0.0, 10.0)
        with pytest.raises(ValueError):
            BooleanParameter(1)

    def test_params_str(self):
        """str() renders the name, the value and, for integers, the bounds."""
        int_param = IntegerParameter(1, 0, 10)
        int_param.name = 'int_param'
        bool_param = BooleanParameter(True)
        bool_param.name = 'bool_param'
        assert str(int_param) == 'int_param : 1 (min_value=0, max_value=10)'
        assert str(bool_param) == 'bool_param : True'
class CSVAdapter(DataAdapter):
    """Data adapter that builds an OHLC object from a csv file."""

    def get_data(
            self,
            start: datetime = None,
            end: datetime = None,
            path: str = None,
            symbol: str = None,
    ):
        """
        Load OHLC data from a csv file, optionally trimmed to a time range.
        :param start: start of the range; defaults to the first index entry when only ``end`` is given
        :param end: end of the range; defaults to the last index entry when only ``start`` is given
        :param path: path to the csv file
        :param symbol: symbol, as a string; when omitted, the file name up to its first dot is used
        :return: OHLC object
        """
        from components.ohlc import OHLC

        if symbol is None:
            logger.warning('No symbol provided, using filename as symbol.')
            filename = path.split('/')[-1]
            symbol = filename.split('.')[0]

        ohlc = OHLC.from_csv(path, symbol)

        # No bounds requested: hand back the full data set.
        if start is None and end is None:
            return ohlc

        # Fill in whichever bound is missing from the data's own index.
        if start is None:
            start = ohlc.index[0]
        if end is None:
            end = ohlc.index[-1]

        ohlc.trim(start, end)
        return ohlc
correlation_coefficient(x: Union[List[float], Series], y: Union[List[float], Series], period: int) -> Series: 10 | """ Calculate the correlation coefficient between two lists """ 11 | if len(x) != len(y): 12 | raise ValueError("Input arrays must have the same length") 13 | 14 | if len(x) < period: 15 | raise ValueError("Period must be less than or equal to the length of input arrays") 16 | 17 | x = np.asarray(x) 18 | y = np.asarray(y) 19 | 20 | result = np.zeros(len(x) - period + 1) 21 | 22 | # calculate the rolling sums for both x and y. 23 | x_sum = np.cumsum(x) 24 | y_sum = np.cumsum(y) 25 | 26 | # calculate the rolling sums for x * y, x^2, and y^2. 27 | xy_sum = np.cumsum(x * y) 28 | x2_sum = np.cumsum(x**2) 29 | y2_sum = np.cumsum(y**2) 30 | 31 | for i in range(len(result)): 32 | if i == 0: 33 | x_sum_window = x_sum[period - 1] 34 | y_sum_window = y_sum[period - 1] 35 | xy_sum_window = xy_sum[period - 1] 36 | x2_sum_window = x2_sum[period - 1] 37 | y2_sum_window = y2_sum[period - 1] 38 | else: 39 | x_sum_window = x_sum[i + period - 1] - x_sum[i - 1] 40 | y_sum_window = y_sum[i + period - 1] - y_sum[i - 1] 41 | xy_sum_window = xy_sum[i + period - 1] - xy_sum[i - 1] 42 | x2_sum_window = x2_sum[i + period - 1] - x2_sum[i - 1] 43 | y2_sum_window = y2_sum[i + period - 1] - y2_sum[i - 1] 44 | 45 | # Calculate the correlation coefficient for the current window. 
class OrderManager:
    """Creates and stores the orders belonging to one strategy."""

    def __init__(self, strategy):
        """
        :param strategy: owning strategy; its ``symbol`` is cached on the
            manager, and its ``data`` and ``name`` are read by other methods
        :raises ValueError: if ``strategy`` is None
        """
        # Validate before touching any attribute, otherwise a missing
        # strategy fails with AttributeError instead of this error.
        if strategy is None:
            raise ValueError('Strategy is required')
        self.orders = []
        self.strategy = strategy
        self.symbol = strategy.symbol

    def market_order(self, side: str, quantity: int):
        """Create, store and return a market order stamped with the strategy's current close and timestamp."""
        order = Order(
            type='market',
            side=side,
            qty=quantity,
            symbol=self.symbol.symbol,
            filled_avg_price=self.strategy.data.close,
            timestamp=self.strategy.data.timestamp,
        )
        self.orders.append(order)
        return order

    def stop_loss_order(self, side: str, quantity: int, price: float):
        """Create, store and return a stop order with ``price`` as its stop price."""
        order = StopOrder(
            type='stop',
            side=side,
            qty=quantity,
            symbol=self.symbol.symbol,
            stop_price=price,
            timestamp=None,
        )
        self.orders.append(order)
        return order

    def limit_order(self, side: str, quantity: int, price: float):
        """Create, store and return a limit order with ``price`` as its limit price."""
        order = LimitOrder(
            type='limit',
            side=side,
            qty=quantity,
            symbol=self.symbol.symbol,
            limit_price=price,
            timestamp=None,
        )
        self.orders.append(order)
        return order

    def add(self, order: 'Order'):
        """Store an order that was built elsewhere."""
        self.orders.append(order)

    def all(self):
        """Return the list of stored orders (the live list, not a copy)."""
        return self.orders

    def filter(self, **kwargs):
        """Return the stored orders whose attributes equal every given keyword value."""
        return [o for o in self.orders if all(getattr(o, k) == v for k, v in kwargs.items())]

    def __len__(self):
        return len(self.orders)

    def summary(self):
        """Return a dict with the total number of stored orders."""
        return {
            'total': len(self.orders),
        }

    def show(self):
        """Print every stored order under a heading naming the strategy."""
        print('showing orders for strategy: {}'.format(self.strategy.name))
        print('\n'.join([str(o) for o in self.orders]))
class PositionManager:
    """Creates, stores and closes the positions of a single strategy."""

    def __init__(self, strategy: 'BaseStrategy'):
        """
        :param strategy: owning strategy; its ``data`` and ``symbol`` are read
            when positions are opened or closed
        """
        self._strategy = strategy
        self.positions: List[Position] = []

    def add(self, position: Position):
        """Adds a position to the manager"""
        # TODO: add validation
        self.positions.append(position)

    def open(self, order_type: str, side: str, quantity: int):
        """Opens a new position

        The opening order is priced at the current close and stamped with the
        timestamp one step ahead of the current index. When the data has no
        such entry, the timestamp is extrapolated from the current one.

        :param order_type: order type passed through to the opening order
        :param side: order side passed through to the opening order
        :param quantity: order quantity
        """

        # get the timestamp (it's offset by 1 because we're using the previous close)
        try:
            timestamp = self._strategy.data.get_timestamp(offset=1)
        except FutureTimestampRequested:
            logger.warning(f'{self._strategy} is opening a position in the future.')
            last_timestamp = self._strategy.data.get_timestamp(offset=0)
            resolution = self._strategy.data.resolution
            timestamp = last_timestamp + (resolution * 60 * 1000)  # convert to milliseconds
            # fromtimestamp() expects seconds, while the timestamp is in
            # milliseconds; passing it unscaled fails for real-world dates.
            readable = datetime.datetime.fromtimestamp(timestamp / 1000)
            logger.warning(f'Extrapolating timestamp to {timestamp} ({readable})'
                           f'(resolution: {resolution})')

        # create order
        order = Order(
            type=order_type,
            side=side,
            qty=quantity,
            symbol=self._strategy.symbol.symbol,
            filled_avg_price=self._strategy.data.close,
            timestamp=timestamp
        )

        self.positions.append(
            Position(
                orders=[order],
            )
        )

    def close(self):
        """Closes the most recent position

        Logs an error and does nothing when there are no positions. Errors
        raised while building the closing order are propagated.
        """
        # Check for the empty case explicitly, so an IndexError raised while
        # building the closing order is not misreported as "no positions".
        if not self.positions:
            logger.error(f'{self._strategy} has no positions to close')
            return
        add_closing_order_to_position(position=self.positions[-1], ohlc=self._strategy.data)

    def all(self):
        """Return the list of managed positions (the live list, not a copy)."""
        return self.positions
class CreateNewStrategy(BaseCommand):
    """Create a new strategy file from the strategy template.

    The strategy name is validated on construction; ``handle`` then renders the
    template with that name and writes it to ``<snake_case_name>.py``.
    """

    help = (
        "The name of the strategy, in pascal case, to create. (ThisIsPascalCase) "
        "This will be converted to snake_case and used as the filename. "
        "Example: 'MyStrategy' will be converted to 'my_strategy.py'."
    )

    template_path = STRATEGY_TEMPLATE_PATH

    strategy_name: str

    def _validate_strategy_name(self):
        """Return True when the name is usable, otherwise raise ValueError.

        The raw name and both of its converted forms (pascal case and snake case)
        must be valid Python identifiers.
        """
        name = self.strategy_name

        # The isinstance check short-circuits, so a non-string name (e.g. None) is
        # reported as an invalid name instead of failing with AttributeError.
        is_valid = isinstance(name, str) and all((
            name.isidentifier(),
            pascal_case(name).isidentifier(),
            pascal_to_snake_case(name).isidentifier(),
        ))
        if is_valid:
            return True

        msg = (
            f'Invalid strategy name: {name}. '
            f'Strategy names must be valid. {wiki_link("https://github.com/robswc/stratis/wiki/Strategies#naming")}'
        )
        raise ValueError(msg)

    def __init__(self, strategy_name: str):
        """Store and validate the strategy name.

        Raises:
            ValueError: if ``strategy_name`` is not a valid strategy name.
        """
        super().__init__()
        self.strategy_name = strategy_name
        self._validate_strategy_name()

    def handle(self, *args, **kwargs):
        """Render the template and write the new strategy file.

        Keyword Args:
            override_path: directory to write into instead of ``storage/strategies``.
            prompt: forwarded to ``create_file``; defaults to True.
        """
        file_name = f'{pascal_to_snake_case(self.strategy_name)}.py'
        override_path = kwargs.get('override_path')
        if override_path:
            strategy_path = f'{override_path}/{file_name}'
        else:
            strategy_path = f'storage/strategies/{file_name}'

        # Read the template with an explicit encoding so the result does not depend
        # on the platform's default locale.
        with open(self.template_path, 'r', encoding='utf-8') as template_file:
            template_contents = template_file.read()

        # Replace the class name placeholder in the template with the provided name
        new_contents = string.Template(template_contents).substitute(
            StrategyName=self.strategy_name
        )

        # create the file
        path, created = create_file(strategy_path, new_contents, prompt=kwargs.get('prompt', True))
        if created:
            print(f'[bold green]Created new strategy at:[/bold green] {path}')
        else:
            print('[bold red]Strategy Not Created[/bold red]')
Strategy.objects.all()] -------------------------------------------------------------------------------- /app/tests/test_order.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | import pytest 4 | from pydantic import ValidationError 5 | 6 | from components.orders.order import Order 7 | from components.orders.enums import OrderType, OrderSide as Side 8 | from components.positions.positions import Position 9 | 10 | TIMESTAMP = int(datetime.now().timestamp()) 11 | DETERMINISTIC_TIMESTAMP = 1610000000 12 | 13 | 14 | class TestOrders: 15 | 16 | def test_hashing(self): 17 | fake_order_1 = Order( 18 | symbol='BTCUSDT', 19 | side=Side.BUY, 20 | type=OrderType.MARKET, 21 | qty=1, 22 | timestamp=TIMESTAMP, 23 | ) 24 | fake_order_2 = Order( 25 | symbol='BTCUSDT', 26 | side=Side.BUY, 27 | type=OrderType.MARKET, 28 | qty=1, 29 | timestamp=TIMESTAMP, 30 | ) 31 | assert fake_order_1 == fake_order_2 32 | 33 | fake_order_3 = Order( 34 | symbol='BTCUSDT', 35 | side=Side.BUY, 36 | type=OrderType.MARKET, 37 | qty=2, 38 | timestamp=TIMESTAMP, 39 | ) 40 | 41 | assert fake_order_1 != fake_order_3 42 | 43 | # test IDs 44 | assert fake_order_1.get_id() == fake_order_2.get_id() 45 | assert fake_order_1.get_id() != fake_order_3.get_id() 46 | 47 | def test_order_validation(self): 48 | # valid orders 49 | Order( 50 | symbol='BTCUSDT', 51 | side=Side.BUY, 52 | type=OrderType.MARKET, 53 | timestamp=TIMESTAMP, 54 | qty=1, 55 | ) 56 | 57 | # pytest, test several invalid orders to ensure they all raise a ValidationError 58 | invalid_orders = [ 59 | # missing symbol 60 | {"side": Side.BUY, "type": OrderType.MARKET, "timestamp": TIMESTAMP, "qty": 1}, 61 | # missing side 62 | {"symbol": "BTCUSDT", "type": OrderType.MARKET, "timestamp": TIMESTAMP, "qty": 1}, 63 | # missing qty 64 | {"symbol": "BTCUSDT", "side": Side.BUY, "type": OrderType.MARKET, "timestamp": TIMESTAMP}, 65 | ] 66 | 67 | for data in invalid_orders: 68 | with 
pytest.raises(ValidationError): 69 | Order(**data) 70 | 71 | 72 | class TestPosition: 73 | def test_position(self): 74 | fake_order_1 = Order( 75 | symbol='BTCUSDT', 76 | side=Side.BUY, 77 | type=OrderType.MARKET, 78 | qty=1, 79 | timestamp=DETERMINISTIC_TIMESTAMP, 80 | ) 81 | fake_order_2 = Order( 82 | symbol='BTCUSDT', 83 | side=Side.SELL, 84 | type=OrderType.MARKET, 85 | qty=1, 86 | timestamp=DETERMINISTIC_TIMESTAMP + 30000, 87 | ) 88 | p = Position(orders=[fake_order_1, fake_order_2]) 89 | assert p._get_id() == "0dbf110bd4db94de539295c65867705d" -------------------------------------------------------------------------------- /app/api/api_v1/endpoints/strategy.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List 3 | 4 | from fastapi import APIRouter 5 | from pydantic import BaseModel 6 | from starlette.responses import Response 7 | 8 | from components import Strategy 9 | from components.backtest.backtest import BacktestResult 10 | from components.ohlc import DataAdapter 11 | from loguru import logger 12 | 13 | router = APIRouter() 14 | 15 | 16 | @router.get("/strategy") 17 | async def list_all_strategies(): 18 | """List all strategies""" 19 | strategies = Strategy.objects.all() 20 | return [s.as_model() for s in strategies] 21 | 22 | 23 | @router.get("/") 24 | async def get_strategy(name: str): 25 | """Get a strategy by name""" 26 | try: 27 | strategy = Strategy.objects.get(name=name) 28 | return strategy.as_model() 29 | except ValueError: 30 | return Response(status_code=404) 31 | 32 | 33 | class RunStrategyResponse(BaseModel): 34 | backtest: BacktestResult 35 | plots: List[dict] 36 | 37 | 38 | class RunStrategyRequest(BaseModel): 39 | strategy: str 40 | parameters: dict 41 | adapter: str 42 | adapter_kwargs: dict 43 | 44 | class Config: 45 | schema_extra = { 46 | "example": { 47 | "strategy": "SMACrossOver", 48 | "parameters": {}, 49 | "adapter": "CSVAdapter", 50 | 
"adapter_kwargs": { 51 | "path": "tests/data/AAPL.csv"} 52 | } 53 | } 54 | 55 | 56 | @router.post("/", response_model=RunStrategyResponse) 57 | async def run_strategy(request: RunStrategyRequest): 58 | """Run a strategy with data""" 59 | 60 | # get arguments from request 61 | name = request.strategy 62 | data_adapter_name = request.adapter 63 | data_adapter_kwargs = request.adapter_kwargs 64 | parameters = request.parameters if request.parameters else {} 65 | 66 | # get strategy and data adapter 67 | try: 68 | da = DataAdapter.objects.get(name=data_adapter_name) 69 | except ValueError: 70 | return Response(status_code=404, content="Data adapter not found") 71 | try: 72 | strategy = Strategy.objects.get(name=name) 73 | except ValueError: 74 | return Response(status_code=404, content="Strategy not found") 75 | 76 | ohlc = da.get_data(**data_adapter_kwargs) 77 | 78 | backtest_result, plots = strategy.run(data=ohlc, parameters=parameters, plots=True) 79 | logger.info(f'Backtest result: {backtest_result.get_overview()}') 80 | logger.info(f'Plots: ({len(plots)})') 81 | return RunStrategyResponse(backtest=backtest_result, plots=[p.as_dict() for p in plots]) 82 | 83 | 84 | class SignalsRequest(BaseModel): 85 | signal_type: str 86 | strategy: RunStrategyRequest 87 | 88 | 89 | @router.post("/signals") 90 | async def run_signals(request: SignalsRequest): 91 | """Run signals""" 92 | print(request.signal_type) 93 | pass 94 | -------------------------------------------------------------------------------- /app/components/parameter.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class ParameterModel(BaseModel): 7 | name: str 8 | value: Union[bool, int, float, str, None] 9 | 10 | 11 | class BaseParameter: 12 | def __init__(self): 13 | self.name = None 14 | self._validate() 15 | 16 | def _validate(self): 17 | raise NotImplementedError 18 | 19 | def __str__(self): 20 
class Parameter:
    """Declare a strategy parameter and wrap it in the matching typed parameter.

    ``value`` holds a ``BooleanParameter``, ``IntegerParameter`` or ``FloatParameter``
    chosen from the Python type of the given value. Extra positional and keyword
    arguments are forwarded to the integer and float parameter constructors.

    Raises:
        ValueError: if the value is not a bool, int or float.
    """

    def __init__(self, value, *args, **kwargs):
        # bool must be checked before int because bool is a subclass of int
        if isinstance(value, bool):
            self.value = BooleanParameter(value)
        elif isinstance(value, int):
            self.value = IntegerParameter(value, *args, **kwargs)
        elif isinstance(value, float):
            self.value = FloatParameter(value, *args, **kwargs)
        else:
            raise ValueError('Invalid parameter type')

    def __index__(self):
        """Delegate integer-index conversion to the wrapped parameter."""
        return self.value.__index__()

    # The conversions below read the raw value stored on the wrapped parameter.
    # Without them, float(Parameter(1.5)) fell back to __index__ and raised
    # TypeError, and bool(Parameter(False)) was always True.

    def __int__(self):
        """Return the wrapped raw value converted to int."""
        return int(self.value.value)

    def __float__(self):
        """Return the wrapped raw value converted to float."""
        return float(self.value.value)

    def __bool__(self):
        """Return the truthiness of the wrapped raw value."""
        return bool(self.value.value)
# eventually create a base series and have different types of series
class Series:
    """Cursor over a sequence of numbers that behaves like its current value.

    A ``Series`` wraps a ``list`` or ``pandas.Series`` together with a loop index.
    Numeric conversions, arithmetic and comparisons all operate on the element at
    the current index; ``advance_index`` moves the cursor one step forward.
    """

    def __init__(self, data: Union[list, pd.Series]):
        """Wrap ``data`` with the cursor positioned on the first element.

        Raises:
            TypeError: if ``data`` is neither a list nor a pandas Series.
        """
        # isinstance (rather than a lookup keyed on the exact type) so that
        # subclasses of list / pd.Series are accepted as well.
        if not isinstance(data, (list, pd.Series)):
            raise TypeError(
                f'Series data must be a list or pandas Series, got {type(data).__name__}'
            )
        self._loop_index = 0
        self._data = data
        self._is_pandas = isinstance(data, pd.Series)
        # Kept as attributes for backward compatibility with code that used them.
        self._index_func = lambda: self._value_at(self._loop_index)
        self._shift_func = lambda n: self._value_at(self._loop_index - n)

    def _value_at(self, position):
        """Return the raw element at an integer position."""
        if self._is_pandas:
            return self._data.iat[position]
        return self._data[position]

    def advance_index(self):
        """Move the cursor to the next element."""
        self._loop_index += 1

    def as_list(self):
        """Return the data as a list.

        A list is returned as-is (same object). For a pandas Series a new list is
        returned in which each NaN is replaced by the next valid value (backfill),
        so leading NaNs take the first valid value.
        """
        if not self._is_pandas:
            return self._data
        # bfill() is the non-deprecated spelling of fillna(method='backfill') and
        # returns a new Series, so the wrapped data is left untouched.
        return self._data.bfill().tolist()

    def __repr__(self):
        try:
            return str(float(self))
        except (IndexError, TypeError, ValueError):
            # e.g. the cursor has moved past the last element; repr must not raise
            return f'Series(len={len(self)}, index={self._loop_index})'

    def __len__(self):
        return len(self._data)

    def __getitem__(self, item):
        if self._is_pandas:
            return self._data.iat[item]
        return self._data[item]

    def __float__(self):
        return float(self._value_at(self._loop_index))

    def shift(self, n=1):
        """Return the value ``n`` steps behind the current index as a float.

        When that position lies before the start of the data, NaN is returned.
        A negative index would otherwise wrap around to the end of the data and
        expose later values to earlier steps.
        """
        position = self._loop_index - n
        if position < 0:
            return math.nan
        return float(self._value_at(position))

    def __int__(self):
        return int(float(self))

    # --- arithmetic on the current value ---

    def __add__(self, other):
        return float(self) + other

    def __sub__(self, other):
        return float(self) - other

    def __mul__(self, other):
        return float(self) * other

    def __truediv__(self, other):
        return float(self) / other

    def __floordiv__(self, other):
        return float(self) // other

    def __mod__(self, other):
        return float(self) % other

    def __pow__(self, other):
        return float(self) ** other

    # --- comparisons on the current value ---

    def __lt__(self, other):
        return float(self) < float(other)

    def __le__(self, other):
        return float(self) <= other

    def __eq__(self, other):
        return float(self) == other

    def __ne__(self, other):
        return float(self) != other

    def __gt__(self, other):
        return float(self) > float(other)

    def __ge__(self, other):
        return float(self) >= other

    # ``&`` and ``|`` are logical (not bitwise): they return one of the operands.

    def __and__(self, other):
        return float(self) and other

    def __or__(self, other):
        return float(self) or other

    # --- unary operators and rounding ---

    def __neg__(self):
        return -float(self)

    def __pos__(self):
        return +float(self)

    def __abs__(self):
        return abs(float(self))

    def __invert__(self):
        # ``~`` is undefined for floats, so invert the integer value instead
        return ~int(self)

    def __round__(self, n=None):
        return round(float(self), n)

    def __trunc__(self):
        return math.trunc(float(self))

    def __floor__(self):
        return math.floor(float(self))

    def __ceil__(self):
        return math.ceil(float(self))

    def __index__(self):
        # __index__ must return an int; returning a float raises TypeError
        return int(self)

    # --- reflected arithmetic (number on the left, Series on the right) ---

    def __radd__(self, other):
        return other + float(self)

    def __rsub__(self, other):
        return other - float(self)

    def __rmul__(self, other):
        return other * float(self)

    def __rtruediv__(self, other):
        return other / float(self)

    def __rfloordiv__(self, other):
        return other // float(self)

    def __rmod__(self, other):
        return other % float(self)

    def __rpow__(self, other):
        return other ** float(self)
def worker(q, data):
    """Drain positions from the queue and test each one against ``data``.

    Each item taken from ``q`` has its ``test(data)`` method called. The function
    returns once the queue is empty. An exception raised by ``test`` propagates to
    the caller after the item has been marked as done.
    """
    while True:
        try:
            position = q.get_nowait()
        except queue.Empty:
            # Nothing was fetched, so there is no task to mark as done. Calling
            # task_done() here would raise ValueError (called too many times).
            break

        try:
            position.test(data)
        finally:
            # Balance the successful get() even if test() raises, so that
            # q.join() cannot block forever.
            q.task_done()
_get_orders_with_filled_timestamp(self, orders): 70 | return [o for o in orders if o.filled_timestamp is not None] 71 | 72 | def _sort_orders(self, orders: List[Order]): 73 | return sorted(orders, key=lambda x: x.timestamp) 74 | 75 | def test(self): 76 | logger.debug(f'Starting backtest for strategy {self.strategy.name}') 77 | 78 | positions = self.strategy.positions.all() 79 | orders = self.strategy.orders.all() 80 | 81 | # create a bounded queue to hold the positions 82 | position_queue = queue.Queue() 83 | for position in positions: 84 | position_queue.put(position) 85 | 86 | # use concurrent futures to test orders in parallel 87 | logger.debug(f'Testing {len(positions)} positions in parallel...') 88 | with concurrent.futures.ThreadPoolExecutor() as executor: 89 | for _ in range(len(positions)): 90 | executor.submit(worker, position_queue, self.data) 91 | 92 | # wait for all positions to be tested 93 | position_queue.join() 94 | 95 | logger.debug(f'Finished testing {len(positions)} positions in parallel.') 96 | 97 | # after all positions have been tested, we can check for overlapping positions 98 | positions = remove_overlapping_positions(positions, max_overlap=0) 99 | 100 | all_position_orders = [] 101 | for p in positions: 102 | all_position_orders += p.orders 103 | 104 | # calculate win/loss ratio 105 | losing_trades = len([p for p in positions if p.pnl < 0]) 106 | winning_trades = len([p for p in positions if p.pnl > 0]) 107 | if losing_trades == 0: 108 | wl_ratio = 1 109 | elif winning_trades == 0: 110 | wl_ratio = 0 111 | else: 112 | wl_ratio = round(winning_trades / (winning_trades + losing_trades), 2) 113 | 114 | # filter and sort orders 115 | result_orders = self._get_orders_with_filled_timestamp(orders + all_position_orders) 116 | # sort ascending by timestamp 117 | result_orders = self._sort_orders(result_orders) 118 | 119 | # create backtest result 120 | self.result = BacktestResult( 121 | pnl=sum([position.pnl for position in positions]), 122 | 
wl_ratio=wl_ratio, 123 | trades=len(positions), 124 | winning_trades=winning_trades, 125 | losing_trades=losing_trades, 126 | positions=positions, 127 | orders=result_orders, 128 | ) 129 | -------------------------------------------------------------------------------- /app/components/orders/order.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import hashlib 3 | from typing import Optional, Union 4 | 5 | from loguru import logger 6 | from pydantic import BaseModel, ValidationError, validator 7 | 8 | from components.orders.enums import TimeInForce, OrderType, OrderSide 9 | 10 | """ 11 | Similar to Alpaca-py 12 | """ 13 | 14 | 15 | class OrderValidationError(Exception): 16 | pass 17 | 18 | 19 | class Order(BaseModel): 20 | id: Optional[str] 21 | timestamp: Union[int, None] 22 | symbol: str 23 | qty: int 24 | notional: Optional[float] 25 | side: OrderSide 26 | type: Optional[OrderType] 27 | time_in_force: TimeInForce = TimeInForce.GTC 28 | extended_hours: Optional[bool] 29 | client_order_id: Optional[str] 30 | filled_avg_price: Optional[float] 31 | filled_timestamp: Optional[int] 32 | did_not_fill: Optional[bool] = False 33 | 34 | @staticmethod 35 | def create_market_order(symbol: str, qty: float, side: OrderSide): 36 | return Order( 37 | symbol=symbol, 38 | qty=qty, 39 | side=side, 40 | type=OrderType.MARKET, 41 | ) 42 | 43 | def _timestamp_to_datetime(self, timestamp: int): 44 | if timestamp is not None: 45 | return datetime.datetime.fromtimestamp(timestamp / 1000).strftime('%Y-%m-%d %H:%M:%S') 46 | else: 47 | return 'TBD' 48 | 49 | def __str__(self): 50 | order_type = OrderType.abbreviation(self.type).upper() 51 | side = self.side.upper() 52 | return f'{order_type} {side} [{abs(self.qty)}] {self.symbol} @ {self.filled_avg_price}\t' \ 53 | f'({self._timestamp_to_datetime(self.timestamp)})' 54 | 55 | def __hash__(self): 56 | h = 
class LimitOrder(Order):
    """Order with a limit price; its type is always ``OrderType.LIMIT``."""

    limit_price: float

    @property
    def price(self):
        """Return the limit price of this order."""
        return self.limit_price

    class Config:
        schema_extra = {
            "example": {
                "id": "b6b6b6b6-b6b6-b6b6-b6b6-b6b6b6b6b6b6",
                "symbol": "AAPL",
                "qty": 100,
                "side": "buy",
                "type": "limit",
                "limit_price": 100.00,
            }
        }

    def __init__(self, **data):
        # Force the type before Order.__init__ runs. Order.__init__ derives ``id``
        # from a hash that includes ``type``; setting the type afterwards left the
        # stored id computed from the wrong type and out of sync with get_id().
        super().__init__(**{**data, 'type': OrderType.LIMIT})

    def __str__(self):
        order_type = OrderType.abbreviation(self.type).upper()
        side = self.side.upper()
        return f'{order_type} {side} [{self.qty}] {self.symbol} @ {self.limit_price}\t({self._timestamp_to_datetime(self.timestamp)})'
@property 136 | def price(self): 137 | return self.stop_price 138 | 139 | class Config: 140 | schema_extra = { 141 | "example": { 142 | "id": "b6b6b6b6-b6b6-b6b6-b6b6-b6b6b6b6b6b6", 143 | "symbol": "AAPL", 144 | "qty": 100, 145 | "side": "buy", 146 | "type": "stop", 147 | "stop_price": 100.00, 148 | } 149 | } 150 | 151 | def __init__(self, **data): 152 | super().__init__(**data) 153 | self.type = OrderType.STOP 154 | 155 | def __str__(self): 156 | order_type = OrderType.abbreviation(self.type).upper() 157 | side = self.side.upper() 158 | return f'{order_type} {side} [{self.qty}] {self.symbol} @ {self.stop_price}\t({self._timestamp_to_datetime(self.timestamp)})' 159 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | stratis-full-logo 3 | stratis-full-logo 4 |

5 | 6 | 7 | # Stratis 8 | 9 | [![License](https://img.shields.io/github/license/robswc/stratis?style=for-the-badge)](https://github.com/robswc/stratis/blob/master/LICENSE) 10 | [![Build](https://img.shields.io/github/actions/workflow/status/robswc/stratis/pytest.yml?style=for-the-badge)]() 11 | [![GitHub repo size](https://img.shields.io/github/repo-size/robswc/stratis?style=for-the-badge)](https://github.com/robswc/stratis) 12 | [![Stars](https://img.shields.io/github/stars/robswc/stratis?style=for-the-badge)](https://github.com/robswc/stratis/stargazers) 13 | [![Twitter Follow](https://img.shields.io/twitter/follow/robswc?label=Twitter!&style=for-the-badge)](https://twitter.com/robswc) 14 | 15 | 16 | 17 | 18 | Stratis is a python-based framework for developing and testing strategies, inspired by the simplicity 19 | of tradingview's [Pinescript](https://www.tradingview.com/pine-script-docs/en/v5/Introduction.html). Currently, 20 | stratis is in the early stages of development, and is not yet ready for production use. However, it is encouraged 21 | to try it out and provide feedback via [GitHub issues](https://github.com/robswc/stratis/issues/new). 22 | 23 | Stratis is a part of [Shenandoah Research's](https://shenandoah.capital/) Open source Trading Software Initiative and developed by Polyad Decision Sciences software engineers! 24 | 25 | 26 | sr-logo 27 | poly-ad-logo 28 | 29 | 30 | #### _Please Note Stratis is under active development and is not yet ready for production use._ 31 | 32 | 33 | ## Basic Example 34 | 35 | The following code demonstrates how to create a strategy that prints the timestamp and close price of the 36 | OHLC data every hour. The `on_step` decorator is used to run the function on every step of the OHLC data. You can find 37 | more info about how to create strategies [here](https://github.com/robswc/stratis/wiki/Strategies). 
38 | 39 | Using the `Strategy` 40 | class, along with the `on_step` decorator, you can create strategies that are as simple or as complex as you want, with 41 | the full power of python at your disposal. 42 | 43 | ```python 44 | class OHLCDemo(Strategy): 45 | 46 | @on_step # on_step decorated, runs every "step" of the OHLC data 47 | def print_ohlc(self): 48 | 49 | # shorthands for the OHLC data 50 | timestamp = self.data.timestamp 51 | close = self.data.close 52 | 53 | # if the timestamp is a multiple of 3600000 (1 hour) 54 | if timestamp % 3600000 == 0: 55 | # create a datetime object from the timestamp 56 | dt = datetime.datetime.fromtimestamp(timestamp / 1000) 57 | if dt.hour == 10: 58 | print(f'{dt}: {close}') 59 | ``` 60 | 61 | ```python 62 | data = CSVAdapter('data/AAPL.csv') 63 | strategy = OHLCDemo().run(data) 64 | ``` 65 | 66 | 67 | ## Table of Contents 68 | 69 | - [Installation](#Installation) 70 | - [Docker](#Docker) 71 | - [Python and NPM](#Python-and-NPM) 72 | 73 | [//]: # (- [Features](#features)) 74 | 75 | ## Installation 76 | 77 | It is heavily recommended to use [Docker](https://www.docker.com/resources/what-container/) to run stratis. This is because stratis requires a number of dependencies that 78 | can be difficult to install without Docker. 79 | 80 | 81 | ### Docker 82 | To install Docker, follow the instructions [here](https://docs.docker.com/get-docker/). 
83 | 84 | Once Docker is installed, you can run stratis by running the following commands: 85 | 86 | ```bash 87 | # Clone the repository 88 | git clone https://github.com/robswc/stratis 89 | 90 | # Change directory to the repository 91 | cd stratis 92 | 93 | # Run the docker-compose file 94 | docker-compose up -d # -d runs the containers in the background 95 | ``` 96 | 97 | The Stratis UI interface should now be accessible via: 98 | 99 | [http://localhost:3000](http://localhost:3000) 100 | 101 | And the Stratis backend (core) should be accessible via: 102 | 103 | [http://localhost:8000](http://localhost:8000) 104 | 105 | ### Python and NPM 106 | 107 | For more advanced usage, you can run the app with Python directly, as it is a FastAPI app under the hood. 108 | Please note, this may or may not work for Windows and MacOS, as I have only tested it on Linux. 109 | 110 | I would recommend using a [virtual environment](https://docs.python.org/3/library/venv.html) for this. 111 | You will also have to [install the requirements](https://pip.pypa.io/en/latest/user_guide/#requirements-files). 112 | The following commands will start the backend of stratis. 113 | 114 | ```bash 115 | cd app # change directory to the app folder 116 | python -m uvicorn main:app --reload # reloads the app on file changes (useful for development) 117 | ``` 118 | 119 | The frontend of stratis is a NextJS app. The repository for the frontend can be found [here](https://github.com/robswc/stratis-ui). 
-------------------------------------------------------------------------------- /app/components/strategy/strategy.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import List, Union 3 | 4 | import pandas as pd 5 | from loguru import logger 6 | from pydantic import BaseModel 7 | 8 | from components.backtest.backtest import Backtest 9 | from components.manager.manager import ComponentManager 10 | from components.ohlc import OHLC 11 | from components.orders.order_manager import OrderManager 12 | from components.positions.position_manager import PositionManager 13 | from components.parameter import BaseParameter, Parameter, ParameterModel 14 | from components.strategy.decorators import extract_decorators 15 | 16 | 17 | class PlotConfig(BaseModel): 18 | color: str = 'blue' 19 | type: str = 'line' 20 | lineWidth: int = 1 21 | 22 | 23 | class Plot: 24 | def __init__(self, series: 'Series', **kwargs): 25 | self.data = series.as_list() 26 | self.name = kwargs.get('name', None) 27 | self.config = PlotConfig(**kwargs) 28 | 29 | def as_dict(self): 30 | return { 31 | 'name': self.name, 32 | 'data': self.data, 33 | 'config': self.config.dict() 34 | } 35 | 36 | 37 | class StrategyManager(ComponentManager): 38 | _components = [] 39 | 40 | 41 | class StrategyModel(BaseModel): 42 | name: str 43 | parameters: List[ParameterModel] 44 | 45 | 46 | class BaseStrategy: 47 | objects = StrategyManager 48 | 49 | @classmethod 50 | def register(cls): 51 | cls.objects.register(cls) 52 | 53 | def __init__(self, data: Union[OHLC, None] = None): 54 | self.name = self.__class__.__name__ 55 | 56 | # handle data 57 | if data is None: 58 | data = OHLC() 59 | self.data = data 60 | self._loop_index = 0 61 | 62 | # create a shortcut to the symbol 63 | print(self.data.symbol) 64 | self.symbol = data.symbol 65 | 66 | # handle parameters 67 | self.parameters: List[BaseParameter] = [] 68 | self._set_parameters() 69 | 70 | self.register() 71 
def _get_all_parameters(self):
    """Return every parameter found on this strategy as a new list.

    Attributes that are ``BaseParameter`` instances are included directly.
    Attributes that are ``Parameter`` wrappers contribute their wrapped
    parameter, which is given the attribute's name. ``self.parameters`` is not
    modified.
    """
    parameters = []
    for attr in dir(self):
        value = self.__getattribute__(attr)
        if isinstance(value, BaseParameter):
            # Previously this appended to self.parameters, so the returned list
            # omitted these entries and the instance list gained duplicates.
            parameters.append(value)
        elif isinstance(value, Parameter):
            p = value.value
            p.name = attr
            parameters.append(p)
    return parameters
sys.modules['components.strategy.series'].Series(self.__getattribute__(attr)) 130 | if isinstance(self.__getattribute__(attr), pd.Series): 131 | sys.modules['components.strategy.series'].Series(self.__getattribute__(attr).to_list()) 132 | 133 | def _get_all_series_data(self): 134 | series = [] 135 | for attr in dir(self): 136 | if isinstance(self.__getattribute__(attr), sys.modules['components.strategy.series'].Series): 137 | series.append(self.__getattribute__(attr)) 138 | return series 139 | 140 | def _get_all_plots(self): 141 | # will use in the future to get all plots 142 | plots = [] 143 | for attr in dir(self): 144 | if isinstance(self.__getattribute__(attr), Plot): 145 | plots.append(self.__getattribute__(attr)) 146 | return plots 147 | 148 | def run(self, data: OHLC, parameters: dict = None, **kwargs): 149 | 150 | # set parameters 151 | if parameters is not None: 152 | for p in self.parameters: 153 | if p.name in parameters: 154 | p.value = parameters[p.name] 155 | 156 | self.__init__(data=data) 157 | self._setup_data(data) 158 | self._create_series() 159 | 160 | # run before methods 161 | for method in self._before_methods: 162 | getattr(self, method)() 163 | 164 | # run step methods 165 | 166 | series = self._get_all_series_data() 167 | 168 | for i in range(len(self.data.dataframe)): 169 | for method in self._step_methods: 170 | getattr(self, method)() 171 | 172 | # advance the index of all series 173 | for s in series: 174 | s.advance_index() 175 | 176 | # advance the index of the data 177 | self.data.advance_index() 178 | 179 | # handle backtest 180 | b = Backtest(strategy=self, data=data) 181 | 182 | # runs the backtest 183 | if self.positions.positions or self.orders.orders: 184 | b.test() 185 | 186 | # run after methods 187 | for method in self._after_methods: 188 | getattr(self, method)() 189 | 190 | # get all plots 191 | plots = self.plots 192 | 193 | if kwargs.get('plots', False): 194 | logger.debug(f'Requested plots, found {len(plots)}') 195 | 
# Template for "no data": the expected OHLCV columns, indexed by 'timestamp'.
EMPTY_DATA = pd.DataFrame(columns=['open', 'high', 'low', 'close', 'volume', 'timestamp'])
EMPTY_DATA.set_index('timestamp', inplace=True)


class FutureTimestampRequested(Exception):
    """Raised by ``OHLC.get_timestamp`` when the row directly after the last one is requested."""

    def __init__(self, dataframe, index):
        super().__init__(f'Future timestamp requested. You must ensure the OHLCs resolution is set for future '
                         f'timestamp extrapolation.\n'
                         f'Dataframe length: {len(dataframe)}, index: {index}\n')


class OHLC:
    """
    OHLCV data class. This class is used to store OHLCV data.
    Wraps around a pandas dataframe and keeps a cursor (``_index``) on the current row.
    Attributes that are not found on the wrapper are forwarded to the dataframe.
    """

    def __init__(
            self,
            symbol: Union['Symbol', None] = None,
            dataframe: Union[pd.DataFrame, None] = None,
            resolution: Union[int, None] = None,
    ):
        """
        :param symbol: symbol the data belongs to.
        :param dataframe: OHLCV frame indexed by ``timestamp``; an empty frame is used when omitted.
        :param resolution: resolution in minutes; interpreted from the index when omitted.
        :raises ValueError: if the dataframe is missing columns or is not indexed by ``timestamp``.
        """
        self.symbol = symbol

        # if no dataframe is provided, use a private copy of the empty template so that
        # instances never share (and cannot mutate) the module-level frame
        if dataframe is None:
            dataframe = EMPTY_DATA.copy()
        self.dataframe = dataframe

        # if no resolution is provided, attempt to interpret it
        if resolution is None:
            try:
                self._interpret_resolution()
            except ValueError:
                logger.error(f'Unable to interpret resolution for {self.symbol}')
                self.resolution = 0
        else:
            self.resolution = resolution
        self._index = 0

        # a dataframe is always present at this point, so always validate it
        self._validate()

    def _interpret_resolution(self):
        """
        Interpret the resolution of the data as the most common gap between consecutive rows.

        All consecutive gaps are used, so the result is deterministic; ties resolve to the
        smallest gap. The gap is divided by 1000 and by 60 to convert it to minutes.

        :raises ValueError: if there are fewer than two rows.
        """
        index_values = self.dataframe.index.to_numpy()
        if len(index_values) < 2:
            raise ValueError('At least two rows are required to interpret the resolution.')

        gaps, counts = np.unique(np.diff(index_values), return_counts=True)

        # get the most common difference, convert to minutes
        self.resolution = int(gaps[counts.argmax()] / 1000 / 60)

    def advance_index(self, n: int = 1):
        """Move the cursor forward by ``n`` rows."""
        self._index += n

    def reset_index(self):
        """Move the cursor back to the first row."""
        self._index = 0

    def _get_ohlc(self, column: str, index: int = None):
        """
        Return the value of ``column`` at ``index`` (the cursor when omitted).

        An out-of-range index is logged and yields ``None`` instead of raising.
        """
        if index is None:
            index = self._index
        try:
            value = self.dataframe[column].iat[index]
        except IndexError:
            logger.error(f'Index out of range. Index: {index}, Length: {len(self.dataframe)}')
            value = None
        return value

    @property
    def open(self):
        """Open of the current row."""
        return self._get_ohlc('open')

    @property
    def high(self):
        """High of the current row."""
        return self._get_ohlc('high')

    @property
    def low(self):
        """Low of the current row."""
        return self._get_ohlc('low')

    @property
    def close(self):
        """Close of the current row."""
        return self._get_ohlc('close')

    @property
    def volume(self):
        """
        The whole ``volume`` column.

        NOTE(review): unlike open/high/low/close this returns the full column rather than the
        value at the cursor; kept as is because callers may rely on it - confirm the intent.
        """
        return self.dataframe['volume']

    @property
    def timestamp(self):
        """Index label of the current row; raises ``IndexError`` when the cursor is out of range."""
        return self.dataframe.index[self._index]

    def get_timestamp(self, offset: int = 0):
        """
        Return the index label ``offset`` rows away from the cursor.

        :raises FutureTimestampRequested: if the requested row is exactly one past the last row.
        :return: the label, or ``None`` for any other out-of-range position.
        """
        position = self._index + offset
        try:
            return self.dataframe.index[position]
        except IndexError:
            if position == len(self.dataframe):
                raise FutureTimestampRequested(self.dataframe, position)
            # any other out-of-range position keeps the historical behaviour of yielding None
            return None

    def all(self, column: str):
        """Return the whole ``column``, or an empty list when the column does not exist."""
        try:
            return self.dataframe[column]
        except KeyError:
            return []

    def __str__(self):
        return f'OHLC: {self.symbol}'

    def __getattr__(self, item):
        """
        Forward attributes that are not found on the wrapper to the dataframe.

        Python only calls this after normal lookup has failed. Dunder names are not forwarded,
        so protocol probes such as ``__deepcopy__`` or ``__setstate__`` are not answered by the
        dataframe. A missing ``dataframe`` attribute (for example while the object is being
        reconstructed by copy/pickle) raises ``AttributeError`` instead of recursing forever.
        """
        if item.startswith('__') and item.endswith('__'):
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
        try:
            dataframe = self.__dict__['dataframe']
        except KeyError:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") from None
        return getattr(dataframe, item)

    def _validate(self):
        """
        Check that the dataframe has the OHLCV columns and is indexed by ``timestamp``.

        :raises ValueError: if a column is missing or the index has another name.
        """
        # ensure data has the correct columns
        if not {'open', 'high', 'low', 'close', 'volume'}.issubset(self.dataframe.columns):
            raise ValueError(f'Invalid data. Missing columns. Expected: open, high, low, close, volume.')

        # ensure the index of data is 'timestamp'
        if self.dataframe.index.name != 'timestamp':
            raise ValueError(f'Invalid data. Index must be named "timestamp", not "{self.dataframe.index.name}".')

    @staticmethod
    def from_csv(path: str, symbol: str):
        """
        Loads data from a csv file.
        :param symbol: symbol for the data
        :param path: Path to csv file.
        :return: a new, validated OHLC object
        """

        # create a symbol from the symbol string
        symbol = Symbol(symbol)

        # load the data from the csv file
        dataframe = pd.read_csv(path)
        dataframe.set_index('timestamp', inplace=True)

        # the constructor validates the dataframe, so no second validation is needed
        return OHLC(symbol=symbol, dataframe=dataframe)

    def trim(self, start: int, end: int):
        """
        Trims the OHLC data to the index labels between ``start`` and ``end`` (both inclusive).
        :param start: first index label to keep.
        :param end: last index label to keep.
        :return: self
        """

        # only used for the log message; the labels are divided by 1000 before conversion
        start_dt = datetime.datetime.fromtimestamp(start / 1000)
        end_dt = datetime.datetime.fromtimestamp(end / 1000)

        logger.debug(f'Trimming data from {start_dt} to {end_dt}')
        logger.debug(f'Original data length: {len(self.dataframe)}')
        # slice first, then copy only the kept rows
        df = self.dataframe.loc[start:end].copy()
        self.dataframe = df
        self._validate()
        logger.debug(f'Trimmed data length: {len(df)}')
        return self

    def __len__(self):
        return len(self.dataframe)

    def to_dict(self):
        """Return ``symbol``, ``resolution`` and the rows as records, with the index added as ``time``."""
        df = self.dataframe.copy()
        df['time'] = df.index
        return {
            'symbol': self.symbol,
            'resolution': self.resolution,
            'data': df.to_dict(orient='records')
        }
.idea/**/workspace.xml 14 | .idea/**/tasks.xml 15 | .idea/**/usage.statistics.xml 16 | .idea/**/dictionaries 17 | .idea/**/shelf 18 | 19 | # AWS User-specific 20 | .idea/**/aws.xml 21 | 22 | # Generated files 23 | .idea/**/contentModel.xml 24 | 25 | # Sensitive or high-churn files 26 | .idea/**/dataSources/ 27 | .idea/**/dataSources.ids 28 | .idea/**/dataSources.local.xml 29 | .idea/**/sqlDataSources.xml 30 | .idea/**/dynamic.xml 31 | .idea/**/uiDesigner.xml 32 | .idea/**/dbnavigator.xml 33 | 34 | # Gradle 35 | .idea/**/gradle.xml 36 | .idea/**/libraries 37 | 38 | # Gradle and Maven with auto-import 39 | # When using Gradle or Maven with auto-import, you should exclude module files, 40 | # since they will be recreated, and may cause churn. Uncomment if using 41 | # auto-import. 42 | # .idea/artifacts 43 | # .idea/compiler.xml 44 | # .idea/jarRepositories.xml 45 | # .idea/modules.xml 46 | # .idea/*.iml 47 | # .idea/modules 48 | # *.iml 49 | # *.ipr 50 | 51 | # CMake 52 | cmake-build-*/ 53 | 54 | # Mongo Explorer plugin 55 | .idea/**/mongoSettings.xml 56 | 57 | # File-based project format 58 | *.iws 59 | 60 | # IntelliJ 61 | out/ 62 | 63 | # mpeltonen/sbt-idea plugin 64 | .idea_modules/ 65 | 66 | # JIRA plugin 67 | atlassian-ide-plugin.xml 68 | 69 | # Cursive Clojure plugin 70 | .idea/replstate.xml 71 | 72 | # SonarLint plugin 73 | .idea/sonarlint/ 74 | 75 | # Crashlytics plugin (for Android Studio and IntelliJ) 76 | com_crashlytics_export_strings.xml 77 | crashlytics.properties 78 | crashlytics-build.properties 79 | fabric.properties 80 | 81 | # Editor-based Rest Client 82 | .idea/httpRequests 83 | 84 | # Android studio 3.1+ serialized cache file 85 | .idea/caches/build_file_checksums.ser 86 | 87 | ### PyCharm Patch ### 88 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 89 | 90 | # *.iml 91 | # modules.xml 92 | # .idea/misc.xml 93 | # *.ipr 94 | 95 | # Sonarlint plugin 96 | # 
https://plugins.jetbrains.com/plugin/7973-sonarlint 97 | .idea/**/sonarlint/ 98 | 99 | # SonarQube Plugin 100 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 101 | .idea/**/sonarIssues.xml 102 | 103 | # Markdown Navigator plugin 104 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 105 | .idea/**/markdown-navigator.xml 106 | .idea/**/markdown-navigator-enh.xml 107 | .idea/**/markdown-navigator/ 108 | 109 | # Cache file creation bug 110 | # See https://youtrack.jetbrains.com/issue/JBR-2257 111 | .idea/$CACHE_FILE$ 112 | 113 | # CodeStream plugin 114 | # https://plugins.jetbrains.com/plugin/12206-codestream 115 | .idea/codestream.xml 116 | 117 | # Azure Toolkit for IntelliJ plugin 118 | # https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij 119 | .idea/**/azureSettings.xml 120 | 121 | ### Python ### 122 | # Byte-compiled / optimized / DLL files 123 | __pycache__/ 124 | *.py[cod] 125 | *$py.class 126 | 127 | # C extensions 128 | *.so 129 | 130 | # Distribution / packaging 131 | .Python 132 | build/ 133 | develop-eggs/ 134 | dist/ 135 | downloads/ 136 | eggs/ 137 | .eggs/ 138 | lib/ 139 | lib64/ 140 | parts/ 141 | sdist/ 142 | var/ 143 | wheels/ 144 | share/python-wheels/ 145 | *.egg-info/ 146 | .installed.cfg 147 | *.egg 148 | MANIFEST 149 | 150 | # PyInstaller 151 | # Usually these files are written by a python script from a template 152 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
153 | *.manifest 154 | *.spec 155 | 156 | # Installer logs 157 | pip-log.txt 158 | pip-delete-this-directory.txt 159 | 160 | # Unit test / coverage reports 161 | htmlcov/ 162 | .tox/ 163 | .nox/ 164 | .coverage 165 | .coverage.* 166 | .cache 167 | nosetests.xml 168 | coverage.xml 169 | *.cover 170 | *.py,cover 171 | .hypothesis/ 172 | .pytest_cache/ 173 | cover/ 174 | 175 | # Translations 176 | *.mo 177 | *.pot 178 | 179 | # Django stuff: 180 | *.log 181 | local_settings.py 182 | db.sqlite3 183 | db.sqlite3-journal 184 | 185 | # Flask stuff: 186 | instance/ 187 | .webassets-cache 188 | 189 | # Scrapy stuff: 190 | .scrapy 191 | 192 | # Sphinx documentation 193 | docs/_build/ 194 | 195 | # PyBuilder 196 | .pybuilder/ 197 | target/ 198 | 199 | # Jupyter Notebook 200 | .ipynb_checkpoints 201 | 202 | # IPython 203 | profile_default/ 204 | ipython_config.py 205 | 206 | # pyenv 207 | # For a library or package, you might want to ignore these files since the code is 208 | # intended to run in multiple environments; otherwise, check them in: 209 | # .python-version 210 | 211 | # pipenv 212 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 213 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 214 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 215 | # install all needed dependencies. 216 | #Pipfile.lock 217 | 218 | # poetry 219 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 220 | # This is especially recommended for binary packages to ensure reproducibility, and is more 221 | # commonly ignored for libraries. 222 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 223 | #poetry.lock 224 | 225 | # pdm 226 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
227 | #pdm.lock 228 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 229 | # in version control. 230 | # https://pdm.fming.dev/#use-with-ide 231 | .pdm.toml 232 | 233 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 234 | __pypackages__/ 235 | 236 | # Celery stuff 237 | celerybeat-schedule 238 | celerybeat.pid 239 | 240 | # SageMath parsed files 241 | *.sage.py 242 | 243 | # Environments 244 | .env 245 | .venv 246 | env/ 247 | venv/ 248 | ENV/ 249 | env.bak/ 250 | venv.bak/ 251 | 252 | # Spyder project settings 253 | .spyderproject 254 | .spyproject 255 | 256 | # Rope project settings 257 | .ropeproject 258 | 259 | # mkdocs documentation 260 | /site 261 | 262 | # mypy 263 | .mypy_cache/ 264 | .dmypy.json 265 | dmypy.json 266 | 267 | # Pyre type checker 268 | .pyre/ 269 | 270 | # pytype static type analyzer 271 | .pytype/ 272 | 273 | # Cython debug symbols 274 | cython_debug/ 275 | 276 | # PyCharm 277 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 278 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 279 | # and can be added to the global gitignore or merged into this file. For a more nuclear 280 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
# Module-level fixtures: the AAPL sample data and one strategy instance built on it.
# NOTE(review): both objects are mutated by several tests below (orders, positions and the
# data cursor), so the tests appear to depend on execution order - confirm before reordering.
OHLC = CSVAdapter().get_data(path='tests/data/AAPL.csv', symbol='AAPL')
STRATEGY = SMACrossOver(data=OHLC)


class TestBacktest:
    """Backtest and Position tests driven by the AAPL sample data."""

    def test_backtest_orders(self):
        """A backtest over a single market order produces a result."""
        # add orders
        STRATEGY.orders.market_order(side='buy', quantity=1)

        backtest = Backtest(strategy=STRATEGY, data=OHLC)
        backtest.test()

        assert backtest.result is not None

    def test_backtest_positions(self):
        """A position opened and closed 100 rows later reports the matching prices, timestamps and pnl."""
        QTY = 1

        # add positions
        STRATEGY.positions.open(order_type='market', side='buy', quantity=QTY)
        # the expected open timestamp is the current timestamp plus 300000 (5 minutes)
        entry_price, opened_timestamp = STRATEGY.data.close, (STRATEGY.data.timestamp + 300000)  # add 5 minutes
        # timestamp
        STRATEGY.data.advance_index(100)
        STRATEGY.positions.close()
        exit_price, closed_timestamp = STRATEGY.data.close, STRATEGY.data.timestamp

        # create and run backtest
        backtest = Backtest(strategy=STRATEGY, data=OHLC)
        backtest.test()

        # grab the newly tested position
        p = STRATEGY.positions.all()[0]

        # calculate expected profit
        expected_profit = (exit_price - entry_price) * QTY

        # check position
        assert p.size == 0
        assert p.pnl == expected_profit
        assert p.unrealized_pnl is None
        assert p.average_entry_price == entry_price
        assert p.average_exit_price == exit_price
        assert p.opened_timestamp == opened_timestamp
        assert p.closed_timestamp == closed_timestamp
        assert backtest.result is not None

    def test_bracket_position(self):
        """A market entry with a limit 2 above and a stop 10 below closes with the asserted pnl."""
        STRATEGY.data.reset_index()
        STRATEGY.data.advance_index(100)
        STRATEGY.orders.orders = []
        root_order = STRATEGY.orders.market_order(side='buy', quantity=1)
        STRATEGY.orders.limit_order(side='sell', quantity=1, price=root_order.price + 2)
        STRATEGY.orders.stop_loss_order(side='sell', quantity=1, price=root_order.price - 10)

        # create position
        p = Position(orders=STRATEGY.orders.all())
        p.test(ohlc=OHLC)

        # check position
        assert p.size == 0
        # exact float comparison: the expected value carries the rounding error of (price + 2) - price
        assert p.pnl == 1.9999999999999716

    def test_stop_loss(self):
        """With a stop 5 below and a limit 100 above the entry, the position closes at -5."""
        strategy = SMACrossOver(data=OHLC)
        strategy.orders.orders = []
        strategy.data.reset_index()
        strategy.data.advance_index(100)

        # add orders
        root_order = strategy.orders.market_order(side='buy', quantity=1)
        strategy.orders.stop_loss_order(side='sell', quantity=1, price=root_order.price - 5)
        strategy.orders.limit_order(side='sell', quantity=1, price=root_order.price + 100)

        # create position
        p = Position(orders=strategy.orders.all())
        p.test(ohlc=OHLC)

        # check position
        assert p.size == 0
        assert p.pnl == -5

    def test_overview_long_orders(self):
        """Three long positions, each held for 5 rows, add up to the asserted pnl."""
        strategy = SMACrossOver(data=OHLC)
        strategy.orders.orders = []
        strategy.data.reset_index()
        strategy.data.advance_index(100)

        # create positions
        for i in range(3):
            strategy.data.advance_index(5)
            strategy.positions.open(order_type='market', side='buy', quantity=1)
            strategy.data.advance_index(5)
            strategy.positions.close()

        # create backtest
        backtest = Backtest(strategy=strategy, data=OHLC)
        backtest.test()

        # check overview
        assert backtest.result.pnl == -5.170000000000016

    def test_overview_short_orders(self):
        """Same schedule as the long test with short positions; the asserted pnl has the opposite sign."""
        # now do the same with short orders
        strategy = SMACrossOver(data=OHLC)
        strategy.orders.orders = []
        strategy.data.reset_index()
        strategy.data.advance_index(100)

        # create positions
        for i in range(3):
            strategy.data.advance_index(5)
            strategy.positions.open(order_type='market', side='sell', quantity=1)
            strategy.data.advance_index(5)
            strategy.positions.close()

        # create backtest
        backtest = Backtest(strategy=strategy, data=OHLC)
        backtest.test()

        # check overview
        assert backtest.result.pnl == 5.170000000000016

    def test_complex_positions(self):
        """Three hand-built positions (+50, +10, +10) give a total pnl of 70."""
        b = Backtest(strategy=SMACrossOver(), data=OHLC)

        # long: bought at 100, sold at 150
        p1 = Position(
            orders=[
                Order(side='buy', symbol='AAPL', qty=1, order_type='market', filled_avg_price=100, timestamp=1),
                Order(side='sell', symbol='AAPL', qty=1, order_type='limit', filled_avg_price=150, timestamp=2),
            ]
        )

        # short: sold at 100, bought back at 90
        p2 = Position(
            orders=[
                Order(side='sell', symbol='AAPL', qty=1, order_type='market', filled_avg_price=100, timestamp=10),
                Order(side='buy', symbol='AAPL', qty=1, order_type='stop', filled_avg_price=90, timestamp=20),
            ]
        )

        # long: bought at 100, sold at 110
        p3 = Position(
            orders=[
                Order(side='buy', symbol='AAPL', qty=1, order_type='market', filled_avg_price=100, timestamp=100),
                Order(side='sell', symbol='AAPL', qty=1, order_type='stop', filled_avg_price=110, timestamp=200),
            ]
        )

        b.strategy.positions.add(p1)
        b.strategy.positions.add(p2)
        b.strategy.positions.add(p3)
        b.test()

        assert b.result.pnl == 70

    def test_backtest_short_with_order_types(self):
        """A filled short entry with a working limit and stop closes at the asserted timestamp."""

        b = Backtest(strategy=SMACrossOver(), data=OHLC)

        side = OrderSide.SELL
        open_order = Order(
            type='market',
            side=side,
            qty=100,
            symbol='AAPL',
            filled_avg_price=257.33,
            timestamp=1653984000000,
            filled_timestamp=1653984000000,
        )
        take_profit = LimitOrder(
            side=OrderSide.inverse(side),
            qty=100,
            symbol='AAPL',
            limit_price=256,
        )
        stop_loss = StopOrder(
            side=OrderSide.inverse(side),
            qty=100,
            symbol='AAPL',
            stop_price=300
        )

        p = Position(orders=[open_order, take_profit, stop_loss])
        b.strategy.positions.add(p)
        b.test()
        tested_position = b.strategy.positions.all()[0]
        print(tested_position)

        assert tested_position.size == 0
        assert tested_position.closed_timestamp == 1653984600000

    def test_backtest_long_with_order_types(self):
        """A filled long entry with a working limit and stop closes at the asserted timestamp."""
        b = Backtest(strategy=SMACrossOver(), data=OHLC)
        side = OrderSide.BUY
        open_order = Order(
            type='market',
            side=side,
            qty=100,
            symbol='AAPL',
            filled_avg_price=257.33,
            timestamp=1653984000000,
            filled_timestamp=1653984000000,
        )
        take_profit = LimitOrder(
            side=OrderSide.inverse(side),
            qty=100,
            symbol='AAPL',
            limit_price=260,
        )
        stop_loss = StopOrder(
            side=OrderSide.inverse(side),
            qty=100,
            symbol='AAPL',
            stop_price=100
        )

        p = Position(orders=[open_order, take_profit, stop_loss])
        b.strategy.positions.add(p)
        b.test()
        tested_position = b.strategy.positions.all()[0]
        assert tested_position.size == 0
        assert tested_position.closed_timestamp == 1654183500000
def binary_search(df, target, is_buy):
    """
    Bisect ``df`` for a row whose ``low`` is at or below ``target`` (``is_buy``) or whose
    ``high`` is at or above ``target`` (not ``is_buy``), and whose previous row does not match.

    NOTE(review): bisection only finds the *first* matching row when matches form a contiguous
    tail of ``df``; price data generally does not guarantee that - confirm before relying on it.

    :param df: frame with ``low`` and ``high`` columns, accessed by position.
    :param target: price level to look for.
    :param is_buy: selects which of the two conditions above is used.
    :return: positional index of the row found, or ``None``.
    """
    low, high = 0, len(df) - 1

    while low <= high:
        mid = (low + high) // 2
        mid_value = df.iloc[mid]

        if (is_buy and mid_value.low <= target) or (not is_buy and mid_value.high >= target):
            # a match whose predecessor does not match is the boundary we are looking for
            if mid == 0 or (is_buy and df.iloc[mid - 1].low > target) or (
                    not is_buy and df.iloc[mid - 1].high < target):
                return mid
            high = mid - 1
        else:
            low = mid + 1

    return None


class Position(BaseModel):
    """
    A position built from a list of orders.

    Feeding filled orders through ``handle_order`` maintains size, cost basis, average
    entry/exit price, realised pnl and the opened/closed timestamps. ``test`` additionally
    resolves working stop/limit orders against OHLC data.
    """
    # md5 of the order ids; filled in by __init__ when not supplied
    id: Optional[str] = None
    orders: List[Order] = []
    # True once the size has returned to 0
    closed: bool = False
    # sum of filled price * qty over the orders that added to the position
    cost_basis: Optional[float] = 0
    average_entry_price: Optional[float] = None
    average_exit_price: Optional[float] = None
    # running sum of order quantities
    size: Optional[int] = 0
    # largest absolute size reached
    largest_size: Optional[int] = 0
    # side of the first handled order
    side: Optional[str] = None
    unrealized_pnl: Optional[float] = None
    pnl: Optional[float] = 0
    opened_timestamp: Optional[int] = None
    closed_timestamp: Optional[int] = None
    is_tested: bool = False

    def __init__(self, **data):
        """
        Build the model from keyword arguments and derive ``id`` when none was given.

        This replaces a method that was spelled ``__int__`` and therefore never ran during
        construction, which left ``id`` permanently ``None``.
        """
        super().__init__(**data)
        if self.id is None:
            self.id = self._get_id()

    def __str__(self):
        # a position that has not handled an order yet has no side; do not crash while logging it
        side_label = self.side.upper() if self.side is not None else 'NONE'
        return f'Position: {""}\t[{side_label}]\t{self.size} {self.average_entry_price} -> ' \
               f'{self.average_exit_price} \t pnl:{round(self.pnl, 2)} (opn:' \
               f' {self.opened_timestamp}, cls: {self.closed_timestamp})'

    def get_size(self):
        """Get the size of all orders in the position."""
        return sum([o.qty for o in self.orders])

    def get_side(self):
        """
        Returns the side of the position, either buy or sell.

        Falls back to the side of the first order while the position has no side yet, and to
        ``None`` when there are no orders either.
        """
        if self.side is None:
            return self.orders[0].side if self.orders else None
        return self.side

    def _get_id(self):
        """Get the id of the position: the md5 hex digest of the list of its order ids."""
        # getattr keeps construction working even for an order object without an ``id``
        order_ids = [getattr(o, 'id', None) for o in self.orders]
        return hashlib.md5(str(order_ids).encode()).hexdigest()

    def _get_root_side_orders(self):
        """Get the root side orders of the position, i.e. those on the position's own side."""
        return [o for o in self.orders if o.side == self.side]

    def _update_average_entry(self, order: Order):
        """
        Fold the fill price of a reducing order into ``average_exit_price``.

        NOTE(review): despite its name this updates the *exit* price, and the running value is
        a plain mean of the previous average and the new fill (not weighted by quantity) -
        confirm that this is intended.
        """
        if self.average_exit_price:
            self.average_exit_price = ((self.average_exit_price + order.filled_avg_price) / 2)
        else:
            self.average_exit_price = order.filled_avg_price

    def _update_pnl(self, order: Order):
        """Add the pnl realised by a reducing order; the sign is flipped for sell-side positions."""
        difference = order.filled_avg_price - self.average_entry_price
        multiplier = -1 if self.get_side() == 'sell' else 1
        realized_pnl = (difference * abs(order.qty)) * multiplier
        self.pnl += realized_pnl

    def _update_largest_size(self):
        """Remember the largest absolute size the position has reached."""
        if abs(self.size) > self.largest_size:
            self.largest_size = abs(self.size)

    def _update_opened_timestamp(self, order: Order):
        """Set the opened timestamp from the first order handled."""
        if self.opened_timestamp is None:
            self.opened_timestamp = order.timestamp

    def _add_order_to_size(self, order: Order):
        """Add the order quantity to the position size."""
        self.size += order.qty

    def _fill_order(self, order: Union[Order, StopOrder, LimitOrder], ohlc: 'OHLC' = None):
        """
        Handles TBD orders. Sets the timestamp and fills the order if it was filled.

        Only rows after the row of the position's first order are considered. An order whose
        type is neither ``stop`` nor ``limit`` is treated as unfilled; previously that case
        raised ``UnboundLocalError``.
        """
        start_index = ohlc.index.get_loc(self.orders[0].timestamp)
        df = ohlc.dataframe.iloc[start_index + 1:]

        filled_order = None
        if order.type == 'stop':
            filled_order = self._process_stop_order(order, df)
        elif order.type == 'limit':
            filled_order = self._process_limit_order(order, df)

        if filled_order is None:
            self._handle_unfilled_order(order)
        else:
            order.timestamp = order.filled_timestamp

    def _process_stop_order(self, order, df: pd.DataFrame):
        """
        Fill a stop order at its stop price on the first row that reaches it.

        A buy-side position triggers when ``low <= stop_price``; otherwise when
        ``high >= stop_price``.

        :return: the filled order, or ``None`` if no row triggers it.
        """
        condition = (df.low <= order.stop_price) if self.side == 'buy' else (df.high >= order.stop_price)
        filtered_df = df[condition]

        if not filtered_df.empty:
            order.filled_timestamp = filtered_df.index[0]
            order.filled_avg_price = order.stop_price
            return order
        return None

    def _process_limit_order(self, order, df: pd.DataFrame):
        """
        Fill a limit order at its limit price on the first row that reaches it.

        A buy-side position triggers when ``high >= limit_price``; otherwise when
        ``low <= limit_price``.

        :return: the filled order, or ``None`` if no row triggers it.
        """
        condition = (df.high >= order.limit_price) if self.side == 'buy' else (df.low <= order.limit_price)
        filtered_df = df[condition]

        if not filtered_df.empty:
            order.filled_timestamp = filtered_df.index[0]
            order.filled_avg_price = order.limit_price
            return order
        return None

    def _handle_unfilled_order(self, order):
        """Log an order that never filled, together with the position, and flag the order."""
        logger.warning(f'Order {order.get_id()} was never filled')
        logger.warning(f'{self}')
        # a distinct loop name: reusing ``order`` here used to flag the last order of the
        # position instead of the one that did not fill
        for position_order in self.orders:
            logger.warning(f'\t{position_order}')
        # set order did not fill to true
        order.did_not_fill = True
        return

    def handle_order(self, order: Union[Order, StopOrder, LimitOrder], ohlc: 'OHLC' = None):
        """
        Apply a filled order to the position.

        Orders without a ``filled_timestamp`` are ignored.

        :raises PositionClosedException: if the position is already closed.
        :raises PositionUnbalancedException: if a reducing order would flip the position's side.
        """

        # if order isn't already filled, it cannot be handled yet
        if order.filled_timestamp is None:
            return

        # if the position is closed, we can't handle any more orders
        if self.closed:
            raise PositionClosedException('Position is already closed')

        # if the position is missing a side, set it to the side of the first order
        if self.side is None:
            self.side = order.side

        # gets the position effect, either add or reduce
        effect = get_effect(position=self, order=order)

        if effect == PositionEffect.ADD:
            # since the position is added, we need to calculate the cost basis
            self.cost_basis += order.filled_avg_price * order.qty
            # adjust the average entry price
            self.average_entry_price = self.cost_basis / (self.size + order.qty)

        if effect == PositionEffect.REDUCE:
            # check that the position will not change sides
            if self.side == 'buy' and self.size + order.qty < 0:
                raise PositionUnbalancedException('Position will change sides')
            if self.side == 'sell' and self.size + order.qty > 0:
                raise PositionUnbalancedException('Position will change sides')

            # since the position is reduced, update the average exit price and the PNL
            self._update_average_entry(order)
            self._update_pnl(order)

        # adjust the size, if the size is 0, the position is closed
        self._add_order_to_size(order)
        self._update_largest_size()
        self._update_opened_timestamp(order)

        # set closed, update closed timestamp if closed
        self.closed = self.size == 0
        if self.closed:
            self.closed_timestamp = order.timestamp

        # (the former "fill in a missing filled_timestamp" step was unreachable because
        # unfilled orders return at the top of this method, so it has been removed)

    def test(self, ohlc: 'OHLC' = None):
        """
        Backtest the position.

        Already-filled orders are applied first. Working orders are then resolved against
        ``ohlc``; only the one that fills earliest is applied and the others are reset to
        unfilled. If no working order fills, nothing further happens.
        """
        # handle all orders with a filled timestamp, as these are already filled
        filled_orders = [o for o in self.orders if o.filled_timestamp is not None]
        for order in filled_orders:
            self.handle_order(order=order, ohlc=ohlc)

        # if there are still working orders, handle them
        working_orders = [o for o in self.orders if o.filled_timestamp is None]
        if not working_orders:
            return

        # handle all orders without a filled timestamp, as these are TBD
        for order in working_orders:
            self._fill_order(order=order, ohlc=ohlc)

        # determine which order was filled first, filter out any orders that were never filled
        filled_working_orders = [o for o in working_orders if o.filled_timestamp is not None]
        if not filled_working_orders:
            # none of the working orders triggered within the data; the position stays as it is
            return
        sorted_orders = sorted(filled_working_orders, key=lambda o: o.timestamp)

        # handle the first order
        first_order = sorted_orders[0]
        self.handle_order(order=first_order, ohlc=ohlc)

        # set the remaining order's filled_timestamp to None
        for order in sorted_orders[1:]:
            order.filled_timestamp = None


class BracketPosition(Position):
    """
    Position carrying dedicated stop and limit orders.

    The two orders are declared as model fields: assigning undeclared attributes inside
    ``__init__`` is rejected by pydantic models, and the former zero-argument ``__init__``
    also prevented passing ``orders``. ``BracketPosition()`` still works as before.
    """
    stop_order: Optional[StopOrder] = None
    limit_order: Optional[LimitOrder] = None