├── .dockerignore
├── app
├── api
│ ├── __init__.py
│ └── api_v1
│ │ ├── __init__.py
│ │ ├── endpoints
│ │ ├── __init__.py
│ │ ├── data.py
│ │ └── strategy.py
│ │ └── api.py
├── core
│ ├── __init__.py
│ └── commands
│ │ ├── __init__.py
│ │ ├── command.py
│ │ └── strategy
│ │ ├── __init__.py
│ │ └── commands.py
├── tests
│ ├── __init__.py
│ ├── temp
│ │ └── .keep
│ ├── test_positions.py
│ ├── test_adapter.py
│ ├── test_command.py
│ ├── test_position_testing.py
│ ├── test_strategy.py
│ ├── test_order_manager.py
│ ├── test_ohlc.py
│ ├── test_parameter.py
│ ├── test_signals.py
│ ├── test_order.py
│ └── test_backtest.py
├── utils
│ ├── __init__.py
│ ├── loaders
│ │ ├── __init__.py
│ │ ├── load_all.py
│ │ └── strategy_loader.py
│ ├── commons.py
│ ├── wiki_link.py
│ ├── formatting.py
│ └── create_file.py
├── benchmarks
│ ├── __init__.py
│ ├── profile_fill_orders.py
│ └── profile_strategy_with_builtins.py
├── components
│ ├── backtest
│ │ ├── __init__.py
│ │ ├── utils.py
│ │ └── backtest.py
│ ├── manager
│ │ ├── __init__.py
│ │ └── manager.py
│ ├── ohlc
│ │ ├── data_adapters
│ │ │ ├── utils.py
│ │ │ ├── __init__.py
│ │ │ ├── api_adapter.py
│ │ │ └── adapter.py
│ │ ├── __init__.py
│ │ ├── symbol.py
│ │ └── ohlc.py
│ ├── strategy
│ │ ├── builtins
│ │ │ ├── __init__.py
│ │ │ └── ta
│ │ │ │ ├── sma.py
│ │ │ │ ├── logic.py
│ │ │ │ ├── __init__.py
│ │ │ │ ├── kalman_filter.py
│ │ │ │ ├── atr.py
│ │ │ │ └── correlation.py
│ │ ├── templates
│ │ │ ├── __init__.py
│ │ │ └── strategy_template.py
│ │ ├── __init__.py
│ │ ├── decorators.py
│ │ ├── series.py
│ │ └── strategy.py
│ ├── positions
│ │ ├── __init__.py
│ │ ├── enums.py
│ │ ├── exceptions.py
│ │ ├── utils.py
│ │ ├── position_manager.py
│ │ └── positions.py
│ ├── orders
│ │ ├── __init__.py
│ │ ├── enums.py
│ │ ├── signals.py
│ │ ├── order_manager.py
│ │ └── order.py
│ ├── __init__.py
│ └── parameter.py
├── settings.py
├── manage.py
├── Dockerfile
├── main.py
├── storage
│ └── strategies
│ │ └── examples
│ │ ├── ohlc_demo.py
│ │ ├── sma_cross_over.py
│ │ ├── using_builtins.py
│ │ └── sma_cross_over_advanced.py
└── requirements.txt
├── .idea
├── vcs.xml
├── misc.xml
├── inspectionProfiles
│ ├── profiles_settings.xml
│ └── Project_Default.xml
├── modules.xml
└── stratis-v2.iml
├── docker-compose-standalone.yml
├── ui
├── README.md
└── Dockerfile
├── nginx.conf
├── docker-compose.yml
├── .github
└── workflows
│ └── pytest.yml
├── LICENSE.md
├── README.md
└── .gitignore
/.dockerignore:
--------------------------------------------------------------------------------
1 | .venv
--------------------------------------------------------------------------------
/app/api/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/core/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/tests/temp/.keep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/api/api_v1/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/benchmarks/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/tests/test_positions.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/core/commands/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/utils/loaders/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/api/api_v1/endpoints/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/components/backtest/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/components/manager/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/components/ohlc/data_adapters/utils.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/components/strategy/builtins/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/components/strategy/templates/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/components/positions/__init__.py:
--------------------------------------------------------------------------------
1 | from components.positions.positions import Position
--------------------------------------------------------------------------------
/app/components/orders/__init__.py:
--------------------------------------------------------------------------------
1 | from components.orders.order import Order, StopOrder, LimitOrder
--------------------------------------------------------------------------------
/app/utils/commons.py:
--------------------------------------------------------------------------------
1 | STRATEGY_TEMPLATE_PATH = 'components/strategy/templates/strategy_template.py'
--------------------------------------------------------------------------------
/app/components/positions/enums.py:
--------------------------------------------------------------------------------
1 | class PositionEffect:
2 | REDUCE = 'reduce'
3 | ADD = 'add'
4 |
--------------------------------------------------------------------------------
/app/components/ohlc/data_adapters/__init__.py:
--------------------------------------------------------------------------------
1 | from components.ohlc.data_adapters.adapter import DataAdapter, CSVAdapter
--------------------------------------------------------------------------------
/app/components/ohlc/__init__.py:
--------------------------------------------------------------------------------
1 | from components.ohlc.ohlc import OHLC
2 | from components.ohlc.data_adapters import DataAdapter, CSVAdapter
--------------------------------------------------------------------------------
/app/utils/wiki_link.py:
--------------------------------------------------------------------------------
1 | def wiki_link(link: str) -> str:
2 | """Returns a link to the wiki."""
3 | return (
4 | f"\n\n\tSee Wiki: {link}\n\n"
5 | )
--------------------------------------------------------------------------------
/app/components/__init__.py:
--------------------------------------------------------------------------------
1 | from components.parameter import Parameter
2 | from components.strategy import Strategy
3 | from components.strategy.decorators import on_step
4 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/docker-compose-standalone.yml:
--------------------------------------------------------------------------------
1 | version: '3'
2 |
3 | services:
4 | app:
5 | build:
6 | context: .
7 | dockerfile: app/Dockerfile
8 | ports:
9 | - "8000:8000"
--------------------------------------------------------------------------------
/app/components/strategy/builtins/ta/sma.py:
--------------------------------------------------------------------------------
1 | from components.strategy import Series
2 |
3 |
4 | def sma(data, period) -> Series:
5 | """Simple Moving Average"""
6 | return Series(data.rolling(period).mean())
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/app/settings.py:
--------------------------------------------------------------------------------
1 | from components.ohlc import CSVAdapter
2 | from components.ohlc.data_adapters.api_adapter import APIDataAdapter
3 |
4 | # add any custom data adapters here
5 | DATA_ADAPTERS = [
6 | CSVAdapter,
7 | APIDataAdapter
8 | ]
--------------------------------------------------------------------------------
/app/components/strategy/__init__.py:
--------------------------------------------------------------------------------
1 | from components.strategy.strategy import BaseStrategy as Strategy
2 | from components.strategy.decorators import on_step
3 | from components.strategy.series import Series
4 | from components.strategy.builtins import ta
--------------------------------------------------------------------------------
/app/core/commands/command.py:
--------------------------------------------------------------------------------
1 | from utils.formatting import snake_case, pascal_case
2 |
3 |
4 | class BaseCommand:
5 |
6 | help = 'Base command'
7 |
8 | def handle(self, *args, **kwargs):
9 | raise NotImplementedError
10 |
--------------------------------------------------------------------------------
/app/components/positions/exceptions.py:
--------------------------------------------------------------------------------
1 | class PositionValidationException(Exception):
2 | pass
3 |
4 |
5 | class PositionUnbalancedException(Exception):
6 | pass
7 |
8 |
9 | class PositionClosedException(Exception):
10 | pass
11 |
--------------------------------------------------------------------------------
/ui/README.md:
--------------------------------------------------------------------------------
1 | # Stratis UI Repo
2 |
3 | [https://github.com/robswc/stratis-ui](https://github.com/robswc/stratis-ui)
4 |
5 | This directory simply contains the Dockerfile for the Stratis UI.
6 | The Dockerfile is used to build the image that is used to run the Stratis UI.
--------------------------------------------------------------------------------
/app/manage.py:
--------------------------------------------------------------------------------
1 | """
2 | A command-line utility for managing the app.
3 | """
4 | from core.commands import strategy
5 |
6 | import typer
7 |
8 | app = typer.Typer()
9 |
10 | app.add_typer(strategy.app, name='strategy')
11 |
12 | if __name__ == '__main__':
13 | app()
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/app/components/strategy/builtins/ta/logic.py:
--------------------------------------------------------------------------------
1 |
2 | class Logic:
3 | @staticmethod
4 | def crossover(a: 'Series', b: 'Series') -> bool:
5 | return a > b and a.shift(1) < b.shift(1)
6 |
7 | @staticmethod
8 | def crossunder(a: 'Series', b: 'Series') -> bool:
9 | return a < b and a.shift(1) > b.shift(1)
--------------------------------------------------------------------------------
/app/api/api_v1/api.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter
2 |
3 | from api.api_v1.endpoints import data
4 | from api.api_v1.endpoints import strategy
5 |
6 | api_router = APIRouter()
7 | api_router.include_router(strategy.router, prefix="/strategy", tags=["strategy"])
8 | api_router.include_router(data.router, prefix="/data", tags=["data"])
--------------------------------------------------------------------------------
/app/components/strategy/builtins/ta/__init__.py:
--------------------------------------------------------------------------------
1 | from components.strategy.builtins.ta.sma import *
2 | from components.strategy.builtins.ta.correlation import *
3 | from components.strategy.builtins.ta.kalman_filter import *
4 | from components.strategy.builtins.ta.atr import *
5 | from components.strategy.builtins.ta.logic import Logic as logic
--------------------------------------------------------------------------------
/app/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.10
2 |
3 | COPY requirements.txt /tmp/requirements.txt
4 | RUN pip install -r /tmp/requirements.txt
5 |
6 | COPY . /app
7 | WORKDIR /app
8 |
9 | EXPOSE 8000
10 |
11 | # run the FastAPI app with gunicorn, using uvicorn workers
12 | CMD ["gunicorn", "-k", "uvicorn.workers.UvicornWorker", "-b", "0.0.0.0:8000", "main:app"]
--------------------------------------------------------------------------------
/app/benchmarks/profile_fill_orders.py:
--------------------------------------------------------------------------------
1 | from components.ohlc import CSVAdapter
2 | from storage.strategies.examples.sma_cross_over_advanced import SMACrossOverAdvanced
3 | from storage.strategies.examples.using_builtins import UsingBuiltins
4 |
5 | adapter = CSVAdapter()
6 | ohlc = adapter.get_data(path='../tests/data/AAPL.csv')
7 |
8 | strategy = SMACrossOverAdvanced()
9 | strategy.run(ohlc)
10 |
11 | # 600ms
12 |
--------------------------------------------------------------------------------
/nginx.conf:
--------------------------------------------------------------------------------
1 | events {}
2 |
3 | http {
4 | upstream app_backend {
5 | server app:8000;
6 | }
7 |
8 | upstream stratis_next_js {
9 | server stratis-next-js:3000;
10 | }
11 |
12 | server {
13 | location /api {
14 | proxy_pass http://app_backend;
15 | }
16 |
17 | location / {
18 | proxy_pass http://stratis_next_js;
19 | }
20 | }
21 | }
--------------------------------------------------------------------------------
/app/benchmarks/profile_strategy_with_builtins.py:
--------------------------------------------------------------------------------
1 | from components.ohlc import CSVAdapter
2 | from storage.strategies.examples.using_builtins import UsingBuiltins
3 |
4 | adapter = CSVAdapter()
5 | ohlc = adapter.get_data(path='../tests/data/AAPL.csv')
6 |
7 | strategy = UsingBuiltins()
8 | strategy.run(ohlc)
9 |
10 | # 3/29/2023
11 | # first run: 1700ms
12 | # second run: 1400ms
13 | # third run: 1000ms
14 | # fourth run: 800ms
15 | # fifth run: 700ms
16 | # sixth run: sub 700ms
17 |
--------------------------------------------------------------------------------
/app/components/strategy/templates/strategy_template.py:
--------------------------------------------------------------------------------
1 | from components import Strategy, on_step, Parameter
2 |
3 |
4 | class $StrategyName(Strategy):
5 | """Edit your strategy description here."""
6 |
7 | # my_parameter = Parameter(10) # Example parameter, uncomment to use
8 |
9 | def __init__(self, *args, **kwargs):
10 | super().__init__(*args, **kwargs)
11 | # add any pre-compiled data here
12 | # https://github.com/robswc/stratis/wiki/Strategies#init-and-pre-compiled-data-series
13 |
14 | @on_step
15 | def do_stuff(self):
16 | print('Running on each step...')
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3'
2 |
3 | services:
4 | app:
5 | build:
6 | context: ./app
7 | dockerfile: Dockerfile
8 | expose:
9 | - "8000"
10 | ports:
11 | - "8000:8000"
12 |
13 | stratis-next-js:
14 | depends_on:
15 | - app
16 | build:
17 | context: ./ui
18 | dockerfile: Dockerfile
19 | expose:
20 | - "3000"
21 |
22 | nginx:
23 | image: nginx:latest
24 | volumes:
25 | - ./nginx.conf:/etc/nginx/nginx.conf
26 | depends_on:
27 | - app
28 | - stratis-next-js
29 | ports:
30 | - "${NGINX_PORT:-3000}:80"
--------------------------------------------------------------------------------
/ui/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM node:18-alpine AS base
2 |
3 | # Install git
4 | RUN apk add --no-cache git
5 |
6 | # Clone from GitHub
7 | ADD https://api.github.com/repos/robswc/stratis-ui latest_commit
8 | RUN git clone https://github.com/robswc/stratis-ui.git /app
9 |
10 | # Set working directory
11 | WORKDIR /app
12 |
13 | # Install dependencies
14 | RUN npm install
15 |
16 | # Add args
17 | ARG MAX_OLD_SPACE_SIZE=1024
18 |
19 | # Build NextJS app
20 | RUN npm run build --max-old-space-size=${MAX_OLD_SPACE_SIZE}
21 |
22 | # Expose port 3000
23 | EXPOSE 3000
24 |
25 | # Run NextJS app
26 | CMD ["npm", "run", "start"]
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/app/main.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI
2 | from starlette.middleware.cors import CORSMiddleware
3 |
4 | from api.api_v1.api import api_router
5 | from utils.loaders import load_all # noqa: F401
6 |
7 | app = FastAPI(
8 | title="Stratis API",
9 | )
10 |
11 | allowed_origins = [
12 | "http://localhost",
13 | "http://localhost:8000",
14 | "http://localhost:3000",
15 | ]
16 |
17 | # handle CORS
18 | app.add_middleware(
19 | CORSMiddleware,
20 | allow_origins=allowed_origins,
21 | allow_credentials=True,
22 | allow_methods=["*"],
23 | allow_headers=["*"],
24 | )
25 |
26 | app.include_router(api_router, prefix="/api/v1")
--------------------------------------------------------------------------------
/app/core/commands/strategy/__init__.py:
--------------------------------------------------------------------------------
1 | import typer
2 |
3 | from core.commands.strategy.commands import CreateNewStrategy, ListStrategies
4 |
5 | app = typer.Typer(
6 | help="Manage strategies",
7 | )
8 |
9 |
10 | @app.command()
11 | def create(name: str = typer.Argument(..., help=CreateNewStrategy.help)):
12 | """
13 | Create a new strategy
14 | """
15 | typer.echo(f"Creating strategy: {name}")
16 | CreateNewStrategy(strategy_name=name).handle()
17 |
18 | @app.command(name="list")
19 | def list_strategies():
20 | """
21 | List all strategies
22 | """
23 | typer.echo("Listing strategies")
24 | ListStrategies().handle()
--------------------------------------------------------------------------------
/app/storage/strategies/examples/ohlc_demo.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from components import Strategy, on_step
3 |
4 |
5 | class OHLCDemo(Strategy):
6 |
7 | @on_step
8 | def print_ohlc(self):
9 |
10 | # strategy shorthands for the OHLC data
11 | timestamp = self.data.timestamp
12 | close = self.data.close
13 |
14 | # if the timestamp is a multiple of 3600000 (1 hour)
15 | if timestamp % 3600000 == 0:
16 | # strategy a datetime object from the timestamp
17 | dt = datetime.datetime.fromtimestamp(timestamp / 1000)
18 | if dt.hour == 10:
19 | print(f'{dt}: {close}')
20 |
21 |
22 |
--------------------------------------------------------------------------------
/app/components/ohlc/symbol.py:
--------------------------------------------------------------------------------
1 | class Symbol:
2 | """
3 | Represents a tradeable instrument. Feel free to extend this class to fit your needs.
4 | """
5 |
6 | def __init__(self, symbol: str):
7 | if type(symbol) != str:
8 | raise TypeError('Symbol must be a string.')
9 | self.symbol = symbol
10 |
11 | def __str__(self):
12 | return self.symbol
13 |
14 |
15 | class Equity(Symbol):
16 | """
17 | Represents an equity instrument.
18 | """
19 |
20 | def __init__(self, symbol: str):
21 | super().__init__(symbol)
22 | self.cusip = None
23 | self.description = ''
24 | self.exchange = None
25 | self.type = None
26 |
--------------------------------------------------------------------------------
/app/tests/test_adapter.py:
--------------------------------------------------------------------------------
1 | class TestDataAdapters:
2 | def test_init_data_adapter(self):
3 | from components.ohlc import DataAdapter
4 | adapter = DataAdapter()
5 | assert adapter.name == 'DataAdapter'
6 |
7 | # test csv adapter
8 | from components.ohlc import CSVAdapter
9 | csv_adapter = CSVAdapter()
10 | assert csv_adapter.name == 'CSVAdapter'
11 |
12 | def test_csv_adapter(self):
13 | from components.ohlc import CSVAdapter
14 | csv_adapter = CSVAdapter()
15 | ohlc = csv_adapter.get_data(None, None, 'tests/data/AAPL.csv', 'AAPL')
16 | assert str(ohlc.symbol) == 'AAPL'
17 | assert ohlc.dataframe.shape == (5001, 5)
18 |
--------------------------------------------------------------------------------
/app/components/manager/manager.py:
--------------------------------------------------------------------------------
1 | from loguru import logger
2 |
3 |
4 | class ComponentManager:
5 |
6 | _components = []
7 |
8 | @classmethod
9 | def register(cls, component):
10 | if component in cls._components:
11 | return
12 | cls._components.append(component)
13 | logger.debug(f'Registered component {component} ({component.__module__})')
14 |
15 | @classmethod
16 | def all(cls):
17 | return [o() for o in cls._components]
18 |
19 | @classmethod
20 | def get(cls, name):
21 | for component in cls._components:
22 | if component.__name__ == name:
23 | return component()
24 | raise ValueError(f'Component {name} not found.')
--------------------------------------------------------------------------------
/app/components/strategy/builtins/ta/kalman_filter.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import pandas as pd
4 | from loguru import logger
5 |
6 | from components.strategy import Series
7 | import numpy as np
8 |
9 |
10 | def kalman_filter(src: pd.Series, gain):
11 | src = np.array(src)
12 | n = len(src)
13 | kf = np.zeros(n)
14 | velo = np.zeros(n)
15 | smooth = np.zeros(n)
16 | gain_sqrt = np.sqrt(gain / 5000)
17 |
18 | for i in range(n):
19 | if i > 0:
20 | dk = src[i] - kf[i - 1]
21 | else:
22 | dk = src[i]
23 | smooth[i] = kf[i - 1] + dk * gain_sqrt if i > 0 else src[i]
24 | velo[i] = velo[i - 1] + (gain / 10000) * dk
25 | kf[i] = smooth[i] + velo[i]
26 | return Series(list(kf))
27 |
--------------------------------------------------------------------------------
/.github/workflows/pytest.yml:
--------------------------------------------------------------------------------
1 | name: Python package
2 |
3 | on: [push]
4 |
5 | jobs:
6 | build:
7 |
8 | runs-on: ubuntu-latest
9 | strategy:
10 | matrix:
11 | python-version: ["3.8", "3.9", "3.10"]
12 |
13 | steps:
14 | - uses: actions/checkout@v3
15 | - name: Set up Python ${{ matrix.python-version }}
16 | uses: actions/setup-python@v4
17 | with:
18 | python-version: ${{ matrix.python-version }}
19 | - name: Install dependencies
20 | working-directory: ./app
21 | run: |
22 | python -m pip install --upgrade pip
23 | pip install pytest
24 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
25 | - name: Test with pytest
26 | working-directory: ./app
27 | run: |
28 | pytest --rootdir=./tests
29 |
--------------------------------------------------------------------------------
/app/utils/formatting.py:
--------------------------------------------------------------------------------
1 |
2 | def camel_case(value):
3 | """Converts a string to camel case."""
4 | return ''.join(x.capitalize() or '_' for x in value.split('_'))
5 |
6 | def snake_case(value):
7 | """Converts a string to snake case."""
8 | return ''.join(x.lower() if x.islower() else '_' + x.lower() for x in value)
9 |
10 | def pascal_case(value):
11 | """Converts a string to pascal case."""
12 | return ''.join(x.capitalize() for x in value.split('_'))
13 |
14 | def pascal_to_snake_case(value: str):
15 | """Converts a string from pascal case to snake case."""
16 | case = ''
17 | for i, c in enumerate(value):
18 | if c.isupper() and i != 0:
19 | case += '_'
20 | case += c.lower()
21 | return case
22 |
23 | def kebab_case(value):
24 | """Converts a string to kebab case."""
25 | return '-'.join(x.lower() for x in value.split('_'))
--------------------------------------------------------------------------------
/app/tests/test_command.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from core.commands.strategy import CreateNewStrategy
4 |
5 |
6 | class TestCommands:
7 | def test_create_strategy(self):
8 |
9 | # Create a new strategy
10 | test_path = 'tests/temp'
11 | CreateNewStrategy(strategy_name='MyTestStrategy').handle(override_path=test_path, prompt=False)
12 | assert os.path.exists(f'{test_path}/my_test_strategy.py')
13 | os.remove(f'{test_path}/my_test_strategy.py')
14 |
15 | # test that exception is raised when strategy name is invalid
16 | try:
17 | CreateNewStrategy(strategy_name='MyTestStrategy').handle(override_path=test_path)
18 | except ValueError:
19 | assert True
20 |
21 | try:
22 | CreateNewStrategy(strategy_name='My Test Strategy').handle(override_path=test_path)
23 | except ValueError:
24 | assert True
--------------------------------------------------------------------------------
/app/requirements.txt:
--------------------------------------------------------------------------------
1 | anyio==3.6.1
2 | attrs==22.1.0
3 | certifi==2022.9.24
4 | charset-normalizer==2.1.1
5 | click==8.1.3
6 | colorama==0.4.6
7 | commonmark==0.9.1
8 | fastapi==0.85.0
9 | gunicorn==20.1.0
10 | h11==0.14.0
11 | httptools==0.5.0
12 | idna==3.4
13 | importlib-metadata==6.0.0
14 | iniconfig==1.1.1
15 | Jinja2==3.1.2
16 | llvmlite==0.39.1
17 | loguru==0.6.0
18 | MarkupSafe==2.1.1
19 | numba==0.56.4
20 | numpy==1.23.4
21 | packaging==21.3
22 | pandas==1.5.1
23 | pluggy==1.0.0
24 | py==1.11.0
25 | pydantic==1.10.2
26 | Pygments==2.14.0
27 | pyparsing==3.0.9
28 | pytest==7.1.3
29 | python-dateutil==2.8.2
30 | python-dotenv==0.21.0
31 | pytz==2022.5
32 | PyYAML==6.0
33 | requests==2.28.1
34 | rich==12.6.0
35 | shellingham==1.5.0.post1
36 | six==1.16.0
37 | sniffio==1.3.0
38 | starlette==0.20.4
39 | tomli==2.0.1
40 | typer==0.7.0
41 | typing_extensions==4.3.0
42 | urllib3==2.6.0
43 | uvicorn==0.18.3
44 | uvloop==0.17.0
45 | watchfiles==0.17.0
46 | websockets==10.3
47 | zipp==3.11.0
48 |
--------------------------------------------------------------------------------
/.idea/stratis-v2.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
20 |
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/app/components/strategy/builtins/ta/atr.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from components.strategy import Series
4 |
5 |
6 | def atr(high: Series, low: Series, close: Series, period: int = 12) -> Series:
7 | """ Calculate the average true range """
8 |
9 | high = np.array(high)
10 | low = np.array(low)
11 | close = np.array(close)
12 |
13 | if len(high) != len(low) != len(close):
14 | raise ValueError("Input lists must have the same length")
15 |
16 | if len(high) < period:
17 | raise ValueError("Input lists must have at least 'period' number of elements")
18 |
19 | true_range = []
20 |
21 | for i in range(1, len(high)):
22 | tr = max(high[i] - low[i], abs(high[i] - close[i - 1]), abs(low[i] - close[i - 1]))
23 | true_range.append(tr)
24 |
25 | result = []
26 |
27 | for i in range(period, len(true_range) + 1):
28 | average = np.mean(true_range[i - period: i])
29 | result.append(average)
30 |
31 | return Series(result)
32 |
--------------------------------------------------------------------------------
/app/components/backtest/utils.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from components.positions import Position
4 |
5 |
6 | def remove_overlapping_positions(positions: List[Position], max_overlap: int = 1):
7 | """Remove overlapping positions from a list of positions."""
8 |
9 | new_positions = []
10 |
11 | # Sort the positions by their opened timestamp
12 | sorted_positions = sorted(positions, key=lambda x: x.opened_timestamp)
13 |
14 | i = 0
15 | last_timestamp = None
16 | while i < len(sorted_positions):
17 | p = sorted_positions[i]
18 | if last_timestamp is None:
19 | last_timestamp = p.closed_timestamp
20 | new_positions.append(p)
21 | continue
22 |
23 | no_overlap = last_timestamp < p.opened_timestamp
24 | if no_overlap:
25 | new_positions.append(p)
26 | last_timestamp = p.closed_timestamp
27 |
28 | i += 1
29 |
30 | # Return the modified copy of the list
31 | return new_positions
32 |
--------------------------------------------------------------------------------
/app/utils/create_file.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def create_file(path: str, contents: str, prompt: bool = False) -> tuple:
5 | """Creates a file at the specified path with the specified contents."""
6 |
7 | # first check that the file doesn't already exist
8 | if prompt:
9 | if os.path.exists(path):
10 | # if the file exists, prompt the user to overwrite it
11 | overwrite = input(f'File already exists at {path}. Overwrite? (y/n): ')
12 | if overwrite.lower() == 'y':
13 | # if the user wants to overwrite it, delete the file
14 | os.remove(path)
15 | else:
16 | print('File not overwritten.')
17 | return path, False
18 |
19 | # strategy the file
20 | try:
21 | with open(path, 'w') as f:
22 | f.write(contents)
23 | f.close()
24 | except Exception as e:
25 | print(f'Error creating file: {e}')
26 | return path, False
27 | return path, True
28 |
--------------------------------------------------------------------------------
/app/tests/test_position_testing.py:
--------------------------------------------------------------------------------
1 | from components.backtest.backtest import Backtest
2 | from components.ohlc import CSVAdapter
3 | from storage.strategies.examples.sma_cross_over import SMACrossOver
4 |
5 |
6 | class TestPositions:
7 | def test_large_amount_of_positions(self):
8 | OHLC = CSVAdapter().get_data(start=None, end=None, path='tests/data/AAPL.csv', symbol='AAPL')
9 | strategy = SMACrossOver(data=OHLC)
10 | strategy.data.advance_index(100)
11 |
12 | # create positions
13 | for i in range(500):
14 | strategy.data.advance_index(5)
15 | strategy.positions.open(order_type='market', side='buy', quantity=1)
16 | strategy.data.advance_index(2)
17 | strategy.positions.close()
18 |
19 | # create backtest
20 | backtest = Backtest(strategy=strategy, data=OHLC)
21 | backtest.test()
22 |
23 | # check overview
24 | print(backtest.result.get_overview())
25 |
26 | assert backtest.result.get_overview().get('trades') == 500
--------------------------------------------------------------------------------
/app/components/strategy/decorators.py:
--------------------------------------------------------------------------------
1 | import inspect
2 |
3 |
4 | def extract_decorators(cls):
5 | befores = []
6 | steps = []
7 | afters = []
8 | methods = [method for method in dir(cls) if callable(getattr(cls, method)) and not method.startswith("__")]
9 | for method in methods:
10 | for text, member in inspect.getmembers(getattr(cls, method)):
11 | if text == '__func__':
12 | if 'before' in str(member):
13 | befores.append(method)
14 | if 'step' in str(member):
15 | steps.append(method)
16 | if 'after' in str(member):
17 | afters.append(method)
18 | return befores, steps, afters
19 |
20 |
21 | def on_step(func):
22 | def wrapper(*args, **kwargs):
23 | return func(*args, **kwargs)
24 | return wrapper
25 |
26 | def before(func):
27 | def wrapper(*args, **kwargs):
28 | return func(*args, **kwargs)
29 | return wrapper
30 |
31 | def after(func):
32 | def wrapper(*args, **kwargs):
33 | return func(*args, **kwargs)
34 | return wrapper
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Robert S.W. Carroll
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/app/components/orders/enums.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
4 | class TimeInForce(str, Enum):
5 | DAY = "day"
6 | GTC = "gtc"
7 | OPG = "opg"
8 | CLS = "cls"
9 | IOC = "ioc"
10 | FOK = "fok"
11 |
12 |
13 | class OrderType(str, Enum):
14 | MARKET = "market"
15 | LIMIT = "limit"
16 | STOP = "stop"
17 | STOP_LIMIT = "stop_limit"
18 | TRAILING_STOP = "trailing_stop"
19 |
20 | @staticmethod
21 | def abbreviation(order_type):
22 | if order_type == OrderType.MARKET:
23 | return "mkt"
24 | elif order_type == OrderType.LIMIT:
25 | return "lmt"
26 | elif order_type == OrderType.STOP:
27 | return "stp"
28 | elif order_type == OrderType.STOP_LIMIT:
29 | return "stpl"
30 | elif order_type == OrderType.TRAILING_STOP:
31 | return "trsl"
32 |
33 |
34 | class OrderSide(str, Enum):
35 | BUY = "buy"
36 | SELL = "sell"
37 |
38 | @staticmethod
39 | def inverse(side):
40 | if side == OrderSide.BUY:
41 | return OrderSide.SELL
42 | elif side == OrderSide.SELL:
43 | return OrderSide.BUY
44 |
--------------------------------------------------------------------------------
/app/components/ohlc/data_adapters/api_adapter.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datetime import datetime
3 |
4 | import pandas as pd
5 | import requests
6 |
7 | from components.ohlc import DataAdapter, OHLC
8 | from components.ohlc.symbol import Symbol
9 |
10 |
11 | class APIDataAdapter(DataAdapter):
12 |
13 | url = os.getenv('DATA_API_URL', None)
14 |
15 | def get_data(
16 | self,
17 | start: datetime = None,
18 | end: datetime = None,
19 | symbol: str = None,
20 | **kwargs
21 | ):
22 |
23 | if self.url is None:
24 | raise Exception('DATA_API_URL not set')
25 |
26 | # you will have to modify this to get the data from the API, this is just an example
27 | r = requests.get(f'{self.url}/data/{symbol}/ohlc/5?only_completed=true')
28 | r.raise_for_status()
29 | candles = r.json().get('candles', [])
30 |
31 | # strategy a dataframe from the candles
32 | df = pd.DataFrame.from_records(candles)
33 | df.set_index('timestamp', inplace=True)
34 |
35 | # strategy a symbol object
36 | symbol = Symbol(symbol)
37 |
38 | # finally, strategy the OHLC object and return it
39 | ohlc = OHLC(symbol=symbol, dataframe=df)
40 | return ohlc
41 |
42 |
--------------------------------------------------------------------------------
/app/components/positions/utils.py:
--------------------------------------------------------------------------------
1 | from components.orders.order import Order
2 | from components.orders.enums import OrderSide as Side
3 | from components.positions.enums import PositionEffect
4 | from components.positions.exceptions import PositionClosedException
5 |
6 |
7 | def add_closing_order_to_position(position, ohlc: 'OHLC'):
8 | """Adds a calculated closing order to the given position."""
9 | if position.closed:
10 | raise PositionClosedException('Position is already closed')
11 |
12 | # create the closing order
13 | order = Order(
14 | type='market',
15 | side=Side.inverse(position.get_side()),
16 | qty=position.get_size(),
17 | symbol=position.orders[0].symbol,
18 | filled_avg_price=ohlc.close,
19 | timestamp=ohlc.timestamp,
20 | )
21 |
22 | position.orders.append(order)
23 |
24 |
25 | def show_details(position: 'Position'):
26 | """Prints the details of the given position."""
27 | print('Position:', position)
28 | for order in position.orders:
29 | print('\tOrder:', order)
30 |
31 |
32 | def get_effect(position: 'Position', order: Order):
33 | """Get the effect of an order on a position."""
34 | if abs(position.size) < abs(position.size + order.qty):
35 | return PositionEffect.ADD
36 | else:
37 | return PositionEffect.REDUCE
38 |
--------------------------------------------------------------------------------
/app/components/orders/signals.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | from typing import Optional
3 |
4 | from pydantic import BaseModel
5 |
6 |
7 | class Signal(BaseModel):
8 | id: Optional[str] = None
9 | order_type: Optional[str] = None
10 | side: Optional[str] = None
11 | quantity: Optional[int] = None
12 | # symbol: Optional[str] = None
13 | price: Optional[float] = None
14 | timestamp: Optional[int] = None
15 |
16 | def from_position(self, position: 'Position'):
17 | self.order_type = position.orders[0].type
18 | self.side = position.get_side()
19 | self.quantity = position.orders[0].qty
20 | self.price = position.orders[0].filled_avg_price
21 | self.timestamp = position.orders[0].timestamp
22 | self.id = self.get_id()
23 | return self
24 |
25 | def get_id(self):
26 | return hashlib.md5(str(self).encode()).hexdigest()
27 |
28 |
29 | class BracketSignal(Signal):
30 | stop_loss: Optional[float] = None
31 | take_profit: Optional[float] = None
32 |
33 | def from_position(self, position: 'Position'):
34 | super().from_position(position)
35 | # TODO: add validation for these
36 | self.stop_loss = [o for o in position.orders if o.type == 'stop'][0].stop_price
37 | self.take_profit = [o for o in position.orders if o.type == 'limit'][0].limit_price
38 | return self
39 |
--------------------------------------------------------------------------------
/app/utils/loaders/load_all.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from pathlib import Path
3 | from importlib import import_module
4 |
5 | from loguru import logger
6 |
7 | from components.ohlc import DataAdapter
8 | from components.strategy.strategy import BaseStrategy
9 |
10 |
11 | def import_components(path, component_type):
12 | logger.debug(f'Importing {component_type.__name__}(s) from {path}...')
13 | components = []
14 | app_path = Path(__file__).parent.parent.parent
15 | paths = app_path.joinpath(path).rglob('*.py')
16 | for path in paths:
17 | module_name = path.as_posix().replace('/', '.').replace('.py', '').split('app.')[1]
18 | module = import_module(module_name)
19 |
20 | for name, obj in inspect.getmembers(module):
21 | if inspect.isclass(obj) and issubclass(obj, component_type):
22 | if obj.__name__ != component_type.__name__:
23 | components.append(obj)
24 | logger.info(f'\t->\t{obj.__name__} ({obj.__module__})')
25 | return components
26 |
27 | # load all components
28 | data_adapters = import_components('components/ohlc/data_adapters', DataAdapter)
29 | strategies = import_components('storage/strategies', BaseStrategy)
30 |
31 | # register all components
32 | for adapter in data_adapters:
33 | adapter.register()
34 |
35 | for strategy in strategies:
36 | strategy.register()
--------------------------------------------------------------------------------
/app/api/api_v1/endpoints/data.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | from fastapi import APIRouter
4 | from loguru import logger
5 | from pydantic import BaseModel
6 |
7 | from components.ohlc import DataAdapter, CSVAdapter
8 |
9 | router = APIRouter()
10 |
11 |
12 | class DataRequest(BaseModel):
13 | start: Union[int, None]
14 | end: Union[int, None]
15 | kwargs: dict
16 |
17 | class Config:
18 | schema_extra = {
19 | "example": {
20 | "start": None,
21 | "end": None,
22 | "kwargs": {
23 | "path": "data/AAPL.csv"
24 | }
25 | }
26 | }
27 |
28 |
29 | @router.post("/{adapter}")
30 | async def get_data(adapter: str, request: DataRequest):
31 | """Get data from an adapter
32 | May eventually be a "GET" request, see: https://github.com/swagger-api/swagger-ui/issues/2136
33 | """
34 | a = DataAdapter.objects.get(adapter)
35 | logger.debug(f'Data Endpoint: Using {a.name}')
36 | logger.debug(f'Data Endpoint: start:{request.start} end:{request.end} kwargs:{request.kwargs}')
37 | data = a.get_data(request.start, request.end, **request.kwargs)
38 | return data.to_dict()
39 |
40 |
41 | @router.get("/adapters", tags=["adapter"])
42 | async def get_adapters():
43 | """List all data adapters"""
44 | return [a.name for a in DataAdapter.objects.all()]
45 |
--------------------------------------------------------------------------------
/app/storage/strategies/examples/sma_cross_over.py:
--------------------------------------------------------------------------------
1 | from components import Parameter
2 | from components import Strategy, on_step
3 | from components.orders.order import Order
4 | from components.strategy import ta
5 | from components.strategy.decorators import after
6 | from components.strategy.strategy import Plot
7 |
8 |
9 | class SMACrossOver(Strategy):
10 | sma_fast_length = Parameter(10)
11 | sma_slow_length = Parameter(60)
12 |
13 | def __init__(self, *args, **kwargs):
14 | super().__init__(*args, **kwargs)
15 | all_close = self.data.all('close')
16 | self.sma_fast = ta.sma(all_close, int(self.sma_fast_length))
17 | self.sma_slow = ta.sma(all_close, int(self.sma_slow_length))
18 |
19 | @on_step
20 | def check_for_crossover(self):
21 | # add logic to crossover here
22 | cross_over = ta.logic.crossover(self.sma_fast, self.sma_slow)
23 | cross_under = ta.logic.crossunder(self.sma_fast, self.sma_slow)
24 | if cross_over:
25 | # self.orders.market_order(side='buy', quantity=1)
26 | self.positions.open(order_type='market', side='buy', quantity=1)
27 | elif cross_under:
28 | # self.orders.market_order(side='sell', quantity=1)
29 | self.positions.close()
30 |
31 | @after
32 | def create_plots(self):
33 | self.export_plots([
34 | Plot(self.sma_fast, name='sma_fast'),
35 | Plot(self.sma_slow, name='sma_slow'),
36 | ])
37 |
38 |
--------------------------------------------------------------------------------
/app/tests/test_strategy.py:
--------------------------------------------------------------------------------
1 | from components.ohlc import OHLC
2 | from components.strategy import Strategy
3 |
4 | CSV_PATH = 'tests/data/AAPL.csv'
5 |
6 |
7 | class TestStrategy:
8 | def test_strategy_name(self):
9 | assert Strategy().name == 'BaseStrategy'
10 |
11 | def test_initializing_examples(self):
12 | from storage.strategies.examples.sma_cross_over import SMACrossOver
13 | strategy = SMACrossOver()
14 | assert strategy.name == 'SMACrossOver'
15 | assert int(strategy.sma_fast_length) == 10
16 | assert int(strategy.sma_slow_length) == 60
17 |
18 | def test_run_strategy(self):
19 | from storage.strategies.examples.sma_cross_over import SMACrossOver
20 | ohlc = OHLC.from_csv(CSV_PATH, 'AAPL')
21 | strategy = SMACrossOver(data=ohlc)
22 | strategy.run(
23 | data=ohlc,
24 | )
25 |
26 | def test_ohlc_demo(self):
27 | from storage.strategies.examples.ohlc_demo import OHLCDemo
28 | strategy = OHLCDemo()
29 | ohlc = OHLC.from_csv(CSV_PATH, 'AAPL')
30 | strategy.run(
31 | data=ohlc,
32 | )
33 |
34 | def test_load_strategies(self):
35 | from utils.loaders.strategy_loader import import_all_strategies
36 | strategies = import_all_strategies()
37 | assert len(strategies) > 0
38 |
39 | def test_strategy_manager(self):
40 | assert len(Strategy.objects.all()) > 0
41 | assert Strategy.objects.get('SMACrossOver').name == 'SMACrossOver'
--------------------------------------------------------------------------------
/app/tests/test_order_manager.py:
--------------------------------------------------------------------------------
1 | from components.ohlc import OHLC, CSVAdapter
2 | from components.ohlc.symbol import Symbol
3 | from components.orders.order import Order
4 | from components.orders.enums import OrderType
5 | from components.orders.order_manager import OrderManager
6 | from components.strategy.strategy import BaseStrategy
7 | from storage.strategies.examples.sma_cross_over import SMACrossOver
8 |
9 | STRATEGY = SMACrossOver(
10 | data=CSVAdapter().get_data(start=None, end=None, path='tests/data/AAPL.csv', symbol='AAPL')
11 | )
12 | CLOSE = STRATEGY.data.close
13 | TIMESTAMP = STRATEGY.data.timestamp
14 | SYMBOL = STRATEGY.data.symbol.symbol
15 |
16 |
17 | class TestOrderManager:
18 | def test_market_order(self):
19 | om = OrderManager(STRATEGY)
20 | om.market_order(side='buy', quantity=1)
21 | om.market_order(side='sell', quantity=1)
22 | assert len(om) == 2
23 |
24 | def test_add(self):
25 | om = OrderManager(STRATEGY)
26 | om.add(
27 | Order(side='buy', qty=1, symbol=SYMBOL, filled_avg_price=CLOSE, timestamp=TIMESTAMP, type=OrderType.MARKET))
28 | om.add(Order(side='sell', qty=1, symbol=SYMBOL, filled_avg_price=CLOSE, timestamp=TIMESTAMP,
29 | type=OrderType.MARKET))
30 | assert len(om) == 2
31 |
32 | def test_summary(self):
33 | om = OrderManager(STRATEGY)
34 | om.market_order(side='buy', quantity=1)
35 | om.market_order(side='sell', quantity=1)
36 | assert om.summary() == {
37 | 'total': 2}
38 |
--------------------------------------------------------------------------------
/app/tests/test_ohlc.py:
--------------------------------------------------------------------------------
1 | from components.ohlc import OHLC
2 | from components.ohlc.symbol import Symbol
3 |
4 | PATH = 'tests/data/AAPL.csv'
5 |
6 |
7 | class TestOHLC:
8 |
9 | def test_from_csv(self):
10 | ohlc = OHLC.from_csv(PATH, 'AAPL')
11 | assert isinstance(ohlc, OHLC)
12 | assert isinstance(ohlc.symbol, Symbol)
13 | assert ohlc.symbol.symbol == 'AAPL'
14 | assert ohlc.shape == (5001, 5)
15 |
16 | def test_attr_forwarding(self):
17 | ohlc = OHLC.from_csv(PATH, 'AAPL')
18 | assert ohlc.shape == (5001, 5)
19 | assert ohlc.head().shape == (5, 5)
20 | assert ohlc.tail().shape == (5, 5)
21 | assert ohlc.describe().shape == (8, 5)
22 |
23 | def test_ohlc_getters(self):
24 | ohlc = OHLC.from_csv(PATH, 'AAPL')
25 | assert ohlc.open == 253.91
26 | assert ohlc.high == 257.33
27 | assert ohlc.low == 252.32
28 | assert ohlc.close == 257.33
29 | ohlc.advance_index()
30 | assert ohlc.open == 257.17
31 | assert ohlc.high == 257.67
32 | assert ohlc.low == 256.48
33 | assert ohlc.close == 257.07
34 |
35 | def test_index(self):
36 | ohlc = OHLC.from_csv(PATH, 'AAPL')
37 | assert ohlc._index == 0
38 | ohlc.advance_index()
39 | assert ohlc._index == 1
40 | ohlc.advance_index(2)
41 | assert ohlc._index == 3
42 | ohlc.reset_index()
43 | assert ohlc._index == 0
44 |
45 | def test_interpret_resolution(self):
46 | ohlc = OHLC.from_csv(PATH, 'AAPL')
47 | assert ohlc.resolution == 5
48 |
--------------------------------------------------------------------------------
/app/utils/loaders/strategy_loader.py:
--------------------------------------------------------------------------------
1 | # dynamically import all strategies in the storage/strategies folder
2 | import inspect
3 | from pathlib import Path
4 | from importlib import import_module
5 | from typing import List, Type
6 | from loguru import logger
7 |
8 | from components.strategy.strategy import BaseStrategy
9 |
10 | def import_all_strategies() -> List[Type[BaseStrategy]]:
11 | strategies = []
12 | app_path = Path(__file__).parent.parent.parent
13 | paths = app_path.joinpath('storage/strategies').rglob('*.py')
14 | for path in paths:
15 | # get the module name from the path
16 | module_name = path.as_posix().replace('/', '.').replace('.py', '')
17 | module_name = module_name.split('app.')[1]
18 | # import the module
19 | module = import_module(module_name)
20 | # get all classes in the module
21 | for name, obj in inspect.getmembers(module):
22 | # dynamically import all strategies in the storage/strategies folder
23 | if inspect.isclass(obj) and issubclass(obj, BaseStrategy):
24 | if obj.__name__ != 'BaseStrategy':
25 | strategies.append(obj)
26 | return strategies
27 |
28 | def register_all_strategies():
29 | strategies = import_all_strategies()
30 | for strategy in strategies:
31 | strategy.objects.register(strategy)
32 |
33 |
34 | loaded_strategies = import_all_strategies()
35 | logger.info(f'Imported {len(loaded_strategies)} strategies')
36 | for s in loaded_strategies:
37 | logger.info(f'\t->\t{s.__name__} ({s.__module__})')
38 |
39 | logger.info('Registering strategies')
40 | register_all_strategies()
41 | logger.info('Strategy loader finished')
--------------------------------------------------------------------------------
/app/tests/test_parameter.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from components.parameter import IntegerParameter, FloatParameter, BooleanParameter, Parameter
4 |
5 |
6 | class TestParameters:
7 |
8 | def test_create_parameters(self):
9 | make_int_param = Parameter(1, 0, 10)
10 | make_int_value_param = Parameter(1)
11 | make_float_param = Parameter(1.0, 0.0, 10.0)
12 | make_bool_param = Parameter(True)
13 | assert isinstance(make_int_param.value, IntegerParameter)
14 | assert isinstance(make_int_value_param.value, IntegerParameter)
15 | assert isinstance(make_float_param.value, FloatParameter)
16 | assert isinstance(make_bool_param.value, BooleanParameter)
17 |
18 | def test_initialize_parameters(self):
19 | int_param = IntegerParameter(1, 0, 10)
20 | float_param = FloatParameter(1.0, 0.0, 10.0)
21 | bool_param = BooleanParameter(True)
22 | assert int_param.value == 1
23 | assert int_param.min_value == 0
24 | assert int_param.max_value == 10
25 | assert float_param.value == 1.0
26 | assert float_param.min_value == 0.0
27 | assert float_param.max_value == 10.0
28 | assert bool_param.value == True
29 |
30 | with pytest.raises(ValueError):
31 | bad_int_param = IntegerParameter(11, 0, 10)
32 | bad_float_param = FloatParameter(11.0, 0.0, 10.0)
33 | bad_bool_param = BooleanParameter(1)
34 |
35 |
36 | def test_params_str(self):
37 | int_param = IntegerParameter(1, 0, 10)
38 | int_param.name = 'int_param'
39 | bool_param = BooleanParameter(True)
40 | bool_param.name = 'bool_param'
41 | assert str(int_param) == 'int_param : 1 (min_value=0, max_value=10)'
42 | assert str(bool_param) == 'bool_param : True'
43 |
44 |
45 |
--------------------------------------------------------------------------------
/app/tests/test_signals.py:
--------------------------------------------------------------------------------
1 | from components.orders.order import Order, StopOrder, LimitOrder
2 | from components.positions.positions import Position
3 | from components.orders.signals import Signal, BracketSignal
4 |
5 | ROOT_ORDER = Order(
6 | type='market',
7 | side='buy',
8 | qty=100,
9 | symbol='AAPL',
10 | filled_avg_price=100,
11 | timestamp=1000
12 | )
13 |
14 |
15 | class TestSignals:
16 | def test_basic_signal(self):
17 | p = Position(
18 | orders=[ROOT_ORDER],
19 | )
20 |
21 | p.test()
22 |
23 | s = Signal().from_position(p)
24 |
25 | assert s.order_type == 'market'
26 | assert s.side == 'buy'
27 | assert s.quantity == 100
28 | assert s.price == 100
29 |
30 | def test_bracket_signal(self):
31 | stop_order = StopOrder(
32 | type='stop',
33 | side='sell',
34 | qty=100,
35 | symbol='AAPL',
36 | stop_price=90,
37 | )
38 | limit_order = LimitOrder(
39 | type='limit',
40 | side='sell',
41 | qty=100,
42 | symbol='AAPL',
43 | limit_price=110,
44 | )
45 |
46 | p = Position(
47 | orders=[ROOT_ORDER, stop_order, limit_order],
48 | )
49 |
50 | s = BracketSignal().from_position(p)
51 |
52 | assert s.order_type == 'market'
53 | assert s.side == 'buy'
54 | assert s.price == 100
55 | assert s.stop_loss == 90
56 | assert s.take_profit == 110
57 |
58 | # check that serializing to JSON works as expected
59 | assert s.json() == ('{"id": "3b8cf71b56751e9e58f69ee2650cf483", "order_type": "market", "side": '
60 | '"buy", "quantity": 100, "price": 100.0, "timestamp": 1000, "stop_loss": '
61 | '90.0, "take_profit": 110.0}')
62 |
--------------------------------------------------------------------------------
/app/components/ohlc/data_adapters/adapter.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 |
3 | import pandas as pd
4 | from loguru import logger
5 |
6 | from components.manager.manager import ComponentManager
7 | from components.ohlc import OHLC
8 |
9 |
10 | class DataAdapterManager(ComponentManager):
11 | _components = []
12 |
13 |
14 | class DataAdapter:
15 | """Base class for data adapters."""
16 |
17 | objects = DataAdapterManager
18 |
19 | @classmethod
20 | def register(cls):
21 | cls.objects.register(cls)
22 |
23 | def __init__(self):
24 | self.name = self.__class__.__name__
25 |
26 | # register
27 | self.register()
28 |
29 | def get_data(self, start: datetime = None, end: datetime = None, *args, **kwargs) -> OHLC:
30 | raise NotImplementedError
31 |
32 |
33 | class CSVAdapter(DataAdapter):
34 | """CSV Adapter, loads data from a csv file."""
35 |
36 | def get_data(
37 | self,
38 | start: datetime = None,
39 | end: datetime = None,
40 | path: str = None,
41 | symbol: str = None,
42 | ):
43 | """
44 | Loads data from a csv file.
45 | :param path: path to csv file
46 | :param symbol: symbol, as a string
47 | :param start: start timestamp
48 | :param end: end timestamp
49 | :return: OHLC object
50 | """
51 | from components.ohlc import OHLC
52 |
53 | if symbol is None:
54 | logger.warning('No symbol provided, using filename as symbol.')
55 | symbol = path.split('/')[-1].split('.')[0]
56 |
57 | ohlc = OHLC.from_csv(path, symbol)
58 |
59 | if start is not None or end is not None:
60 |
61 | # ensure start and end are timestamps
62 | if start is None:
63 | start = ohlc.index[0]
64 | if end is None:
65 | end = ohlc.index[-1]
66 |
67 | ohlc.trim(start, end)
68 |
69 | return ohlc
70 |
--------------------------------------------------------------------------------
/app/components/strategy/builtins/ta/correlation.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from components.strategy import Series
7 |
8 |
9 | def correlation_coefficient(x: Union[List[float], Series], y: Union[List[float], Series], period: int) -> Series:
10 | """ Calculate the correlation coefficient between two lists """
11 | if len(x) != len(y):
12 | raise ValueError("Input arrays must have the same length")
13 |
14 | if len(x) < period:
15 | raise ValueError("Period must be less than or equal to the length of input arrays")
16 |
17 | x = np.asarray(x)
18 | y = np.asarray(y)
19 |
20 | result = np.zeros(len(x) - period + 1)
21 |
22 | # calculate the rolling sums for both x and y.
23 | x_sum = np.cumsum(x)
24 | y_sum = np.cumsum(y)
25 |
26 | # calculate the rolling sums for x * y, x^2, and y^2.
27 | xy_sum = np.cumsum(x * y)
28 | x2_sum = np.cumsum(x**2)
29 | y2_sum = np.cumsum(y**2)
30 |
31 | for i in range(len(result)):
32 | if i == 0:
33 | x_sum_window = x_sum[period - 1]
34 | y_sum_window = y_sum[period - 1]
35 | xy_sum_window = xy_sum[period - 1]
36 | x2_sum_window = x2_sum[period - 1]
37 | y2_sum_window = y2_sum[period - 1]
38 | else:
39 | x_sum_window = x_sum[i + period - 1] - x_sum[i - 1]
40 | y_sum_window = y_sum[i + period - 1] - y_sum[i - 1]
41 | xy_sum_window = xy_sum[i + period - 1] - xy_sum[i - 1]
42 | x2_sum_window = x2_sum[i + period - 1] - x2_sum[i - 1]
43 | y2_sum_window = y2_sum[i + period - 1] - y2_sum[i - 1]
44 |
45 | # Calculate the correlation coefficient for the current window.
46 | numerator = period * xy_sum_window - x_sum_window * y_sum_window
47 | denominator = np.sqrt((period * x2_sum_window - x_sum_window**2) * (period * y2_sum_window - y_sum_window**2))
48 | result[i] = numerator / denominator
49 |
50 | pad_size = period - 1
51 | result = np.pad(result, (pad_size, 0), mode='constant', constant_values=np.nan)
52 | return Series(list(result))
--------------------------------------------------------------------------------
/app/components/orders/order_manager.py:
--------------------------------------------------------------------------------
1 | from components.orders.order import Order, StopOrder, LimitOrder
2 |
3 |
4 | class OrderManager:
5 | def __init__(self, strategy):
6 | self.orders = []
7 | self.strategy = strategy
8 | self.symbol = strategy.symbol
9 | if strategy is None:
10 | raise ValueError('Strategy is required')
11 |
12 |
13 | def market_order(self, side: str, quantity: int):
14 | order = Order(
15 | type='market',
16 | side=side,
17 | qty=quantity,
18 | symbol=self.symbol.symbol,
19 | filled_avg_price=self.strategy.data.close,
20 | timestamp=self.strategy.data.timestamp,
21 | )
22 | self.orders.append(order)
23 | return order
24 |
25 | def stop_loss_order(self, side: str, quantity: int, price: float):
26 | order = StopOrder(
27 | type='stop',
28 | side=side,
29 | qty=quantity,
30 | symbol=self.symbol.symbol,
31 | stop_price=price,
32 | timestamp=None,
33 | )
34 | self.orders.append(order)
35 | return order
36 |
37 | def limit_order(self, side: str, quantity: int, price: float):
38 | order = LimitOrder(
39 | type='limit',
40 | side=side,
41 | qty=quantity,
42 | symbol=self.symbol.symbol,
43 | limit_price=price,
44 | timestamp=None,
45 | )
46 | self.orders.append(order)
47 | return order
48 |
49 | def add(self, order: Order):
50 | self.orders.append(order)
51 |
52 | def all(self):
53 | return self.orders
54 |
55 | def filter(self, **kwargs):
56 | return [o for o in self.orders if all([o.__getattribute__(k) == v for k, v in kwargs.items()])]
57 |
58 | def __len__(self):
59 | return len(self.orders)
60 |
61 | def summary(self):
62 | return {
63 | 'total': len(self.orders),
64 | }
65 |
66 | def show(self):
67 | print('showing orders for strategy: {}'.format(self.strategy.name))
68 | print('\n'.join([str(o) for o in self.orders]))
--------------------------------------------------------------------------------
/app/storage/strategies/examples/using_builtins.py:
--------------------------------------------------------------------------------
1 | from components import Parameter
2 | from components import Strategy, on_step
3 | from components.orders.order import Order
4 | from components.strategy import ta
5 | from components.strategy.decorators import after
6 | from components.strategy.strategy import Plot
7 |
8 |
9 | class UsingBuiltins(Strategy):
10 | sma_fast_length = Parameter(10)
11 | sma_slow_length = Parameter(60)
12 |
13 | def __init__(self, *args, **kwargs):
14 | super().__init__(*args, **kwargs)
15 | all_close = self.data.all('close')
16 | if len(self.data) > 0:
17 | self.sma_faster = ta.sma(all_close, int(self.sma_fast_length) - 5)
18 | self.sma_fast = ta.sma(all_close, int(self.sma_fast_length))
19 | self.sma_slow = ta.sma(all_close, int(self.sma_slow_length))
20 | self.sma_slower = ta.sma(all_close, int(self.sma_slow_length) + 10)
21 | self.atr = ta.atr(self.data.all('high'), self.data.all('low'), all_close, 14)
22 | self.kf_1 = ta.kalman_filter(all_close, 600)
23 | self.kf_2 = ta.kalman_filter(all_close, 600)
24 | self.kf_3 = ta.kalman_filter(all_close, 600)
25 | self.kf_4 = ta.kalman_filter(all_close, 400)
26 | self.kf_5 = ta.kalman_filter(all_close, 500)
27 | self.kf_6 = ta.kalman_filter(all_close, 600)
28 | self.kf_7 = ta.kalman_filter(all_close, 700)
29 | self.correlation = ta.correlation_coefficient(self.sma_fast, self.sma_slow, 14)
30 |
31 | @on_step
32 | def check_for_crossover(self):
33 | # add logic to crossover here
34 | cross_over = ta.logic.crossover(self.sma_fast, self.sma_slow)
35 | cross_under = ta.logic.crossunder(self.sma_fast, self.sma_slow)
36 | corr = self.correlation > 0.5
37 | if cross_over:
38 | # self.orders.market_order(side='buy', quantity=1)
39 | self.positions.open(order_type='market', side='buy', quantity=1)
40 | elif cross_under:
41 | # self.orders.market_order(side='sell', quantity=1)
42 | self.positions.close()
43 |
44 | @after
45 | def create_plots(self):
46 | self.export_plots([
47 | Plot(self.sma_fast, name='sma_fast'),
48 | Plot(self.sma_slow, name='sma_slow'),
49 | ])
50 |
51 |
--------------------------------------------------------------------------------
/app/components/positions/position_manager.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from typing import List
3 |
4 | from components.ohlc.ohlc import FutureTimestampRequested
5 | from components.orders.order import Order
6 | from components.positions.positions import Position
7 | from components.positions.utils import add_closing_order_to_position
8 |
9 | from loguru import logger
10 |
11 |
12 | class PositionManager:
13 | def __init__(self, strategy: 'BaseStrategy'):
14 | self._strategy = strategy
15 | self.positions: List[Position] = []
16 |
17 | def add(self, position: Position):
18 | """Adds a position to the manager"""
19 | # TODO: add validation
20 | self.positions.append(position)
21 |
22 | def open(self, order_type: str, side: str, quantity: int):
23 | """Opens a new position"""
24 |
25 | # get the timestamp (it's offset by 1 because we're using the previous close)
26 | try:
27 | timestamp = self._strategy.data.get_timestamp(offset=1)
28 | except FutureTimestampRequested:
29 | logger.warning(f'{self._strategy} is opening a position in the future.')
30 | last_timestamp = self._strategy.data.get_timestamp(offset=0)
31 | resolution = self._strategy.data.resolution
32 | timestamp = last_timestamp + (resolution * 60 * 1000) # convert to milliseconds
33 | logger.warning(f'Extrapolating timestamp to {timestamp} ({datetime.datetime.fromtimestamp(timestamp)})'
34 | f'(resolution: {resolution})')
35 |
36 | # create order
37 | order = Order(
38 | type=order_type,
39 | side=side,
40 | qty=quantity,
41 | symbol=self._strategy.symbol.symbol,
42 | filled_avg_price=self._strategy.data.close,
43 | timestamp=timestamp
44 | )
45 |
46 | self.positions.append(
47 | Position(
48 | orders=[order],
49 | )
50 | )
51 |
52 | def close(self):
53 | """Closes the most recent position"""
54 | try:
55 | position_to_close = self.positions[-1]
56 | add_closing_order_to_position(position=position_to_close, ohlc=self._strategy.data)
57 | except IndexError:
58 | logger.error(f'{self._strategy} has no positions to close')
59 |
60 | def all(self):
61 | return self.positions
62 |
--------------------------------------------------------------------------------
/app/storage/strategies/examples/sma_cross_over_advanced.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | from components import Parameter
4 | from components import Strategy, on_step
5 | from components.orders.order import Order, LimitOrder, StopOrder
6 | from components.positions.positions import Position
7 | from components.strategy import ta
8 | from components.strategy.decorators import after
9 | from components.strategy.strategy import Plot
10 |
11 |
12 | class SMACrossOverAdvanced(Strategy):
13 | sma_fast_length = Parameter(10)
14 | sma_slow_length = Parameter(100)
15 |
16 | def __init__(self, *args, **kwargs):
17 | super().__init__(*args, **kwargs)
18 | all_close = self.data.all('close')
19 | self.sma_fast = ta.sma(all_close, int(self.sma_fast_length))
20 | self.sma_slow = ta.sma(all_close, int(self.sma_slow_length))
21 |
22 | @on_step
23 | def check_for_crossover(self):
24 | # add logic to crossover here
25 | cross_over = ta.logic.crossover(self.sma_fast, self.sma_slow)
26 | # filled timestamp must be set to "Now" + 5 minutes as the order is technically filled at the next candle
27 | filled_timestamp = datetime.datetime.fromtimestamp(self.data.timestamp / 1000) + datetime.timedelta(minutes=5)
28 | if cross_over:
29 | open_order = Order(
30 | type='market',
31 | side='buy',
32 | qty=100,
33 | symbol=self.symbol.symbol,
34 | filled_avg_price=self.data.close,
35 | timestamp=filled_timestamp.timestamp() * 1000,
36 | filled_timestamp=filled_timestamp.timestamp() * 1000,
37 | )
38 | take_profit = LimitOrder(
39 | side='sell',
40 | qty=100,
41 | symbol=self.symbol.symbol,
42 | limit_price=self.data.close + 1,
43 | )
44 | stop_loss = StopOrder(
45 | side='sell',
46 | qty=100,
47 | symbol=self.symbol.symbol,
48 | stop_price=self.data.close - 1,
49 | )
50 | p = Position(orders=[open_order, take_profit, stop_loss])
51 | self.positions.add(p)
52 |
53 | @after
54 | def create_plots(self):
55 | self.export_plots([
56 | Plot(self.sma_fast, name='sma_fast'),
57 | Plot(self.sma_slow, name='sma_slow'),
58 | ])
59 |
60 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
--------------------------------------------------------------------------------
/app/core/commands/strategy/commands.py:
--------------------------------------------------------------------------------
1 | import os
2 | import string
3 |
4 | from components.strategy import Strategy
5 | from core.commands.command import BaseCommand
6 | from utils.commons import STRATEGY_TEMPLATE_PATH
7 | from utils.create_file import create_file
8 | from utils.formatting import snake_case, pascal_case, pascal_to_snake_case
9 | from utils.wiki_link import wiki_link
10 | from rich import print
11 |
12 |
13 | class CreateNewStrategy(BaseCommand):
14 | """Creates a new strategy."""
15 |
16 | help = (
17 | "The name of the strategy, in pascal case, to strategy. (ThisIsPascalCase) "
18 | "This will be converted to snake_case and used as the filename. "
19 | "Example: 'MyStrategy' will be converted to 'my_strategy.py'."
20 | )
21 |
22 | template_path = STRATEGY_TEMPLATE_PATH
23 |
24 | strategy_name: str
25 |
26 | def _validate_strategy_name(self):
27 |
28 | validations = [
29 | self.strategy_name.isidentifier(),
30 | pascal_case(self.strategy_name).isidentifier(),
31 | pascal_to_snake_case(self.strategy_name).isidentifier(),
32 | ]
33 |
34 | if all(validations):
35 | return True
36 | msg = (
37 | f'Invalid strategy name: {self.strategy_name}. '
38 | f'Strategy names must valid. {wiki_link("https://github.com/robswc/stratis/wiki/Strategies#naming")}'
39 | )
40 | raise ValueError(msg)
41 |
42 | def __init__(self, strategy_name: str):
43 | super().__init__()
44 | self.strategy_name = strategy_name
45 | self._validate_strategy_name()
46 |
47 | def handle(self, *args, **kwargs):
48 | if kwargs.get('override_path'):
49 | strategy_path = f'{kwargs.get("override_path")}/{pascal_to_snake_case(self.strategy_name)}.py'
50 | else:
51 | strategy_path = f'storage/strategies/{pascal_to_snake_case(self.strategy_name)}.py'
52 |
53 | # Open the template file and read its contents
54 | with open(self.template_path, 'r') as template_file:
55 | template_contents = template_file.read()
56 |
57 | # Replace the class name in the template contents with the provided name
58 | new_contents = string.Template(template_contents).substitute(
59 | StrategyName=self.strategy_name
60 | )
61 |
62 | # strategy the file
63 | path, created = create_file(strategy_path, new_contents, prompt=kwargs.get('prompt', True))
64 | if created:
65 | print(f'[bold green]Created new strategy at:[/bold green] {path}')
66 | else:
67 | print(f'[bold red]Strategy Not Created[/bold red]')
68 |
69 | class ListStrategies(BaseCommand):
70 |
71 | help = (
72 | "Lists all registered strategies"
73 | )
74 |
75 | def handle(self, *args, **kwargs):
76 | return [s.name for s in Strategy.objects.all()]
--------------------------------------------------------------------------------
/app/tests/test_order.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 |
3 | import pytest
4 | from pydantic import ValidationError
5 |
6 | from components.orders.order import Order
7 | from components.orders.enums import OrderType, OrderSide as Side
8 | from components.positions.positions import Position
9 |
10 | TIMESTAMP = int(datetime.now().timestamp())
11 | DETERMINISTIC_TIMESTAMP = 1610000000
12 |
13 |
14 | class TestOrders:
15 |
16 | def test_hashing(self):
17 | fake_order_1 = Order(
18 | symbol='BTCUSDT',
19 | side=Side.BUY,
20 | type=OrderType.MARKET,
21 | qty=1,
22 | timestamp=TIMESTAMP,
23 | )
24 | fake_order_2 = Order(
25 | symbol='BTCUSDT',
26 | side=Side.BUY,
27 | type=OrderType.MARKET,
28 | qty=1,
29 | timestamp=TIMESTAMP,
30 | )
31 | assert fake_order_1 == fake_order_2
32 |
33 | fake_order_3 = Order(
34 | symbol='BTCUSDT',
35 | side=Side.BUY,
36 | type=OrderType.MARKET,
37 | qty=2,
38 | timestamp=TIMESTAMP,
39 | )
40 |
41 | assert fake_order_1 != fake_order_3
42 |
43 | # test IDs
44 | assert fake_order_1.get_id() == fake_order_2.get_id()
45 | assert fake_order_1.get_id() != fake_order_3.get_id()
46 |
47 | def test_order_validation(self):
48 | # valid orders
49 | Order(
50 | symbol='BTCUSDT',
51 | side=Side.BUY,
52 | type=OrderType.MARKET,
53 | timestamp=TIMESTAMP,
54 | qty=1,
55 | )
56 |
57 | # pytest, test several invalid orders to ensure they all raise a ValidationError
58 | invalid_orders = [
59 | # missing symbol
60 | {"side": Side.BUY, "type": OrderType.MARKET, "timestamp": TIMESTAMP, "qty": 1},
61 | # missing side
62 | {"symbol": "BTCUSDT", "type": OrderType.MARKET, "timestamp": TIMESTAMP, "qty": 1},
63 | # missing qty
64 | {"symbol": "BTCUSDT", "side": Side.BUY, "type": OrderType.MARKET, "timestamp": TIMESTAMP},
65 | ]
66 |
67 | for data in invalid_orders:
68 | with pytest.raises(ValidationError):
69 | Order(**data)
70 |
71 |
72 | class TestPosition:
73 | def test_position(self):
74 | fake_order_1 = Order(
75 | symbol='BTCUSDT',
76 | side=Side.BUY,
77 | type=OrderType.MARKET,
78 | qty=1,
79 | timestamp=DETERMINISTIC_TIMESTAMP,
80 | )
81 | fake_order_2 = Order(
82 | symbol='BTCUSDT',
83 | side=Side.SELL,
84 | type=OrderType.MARKET,
85 | qty=1,
86 | timestamp=DETERMINISTIC_TIMESTAMP + 30000,
87 | )
88 | p = Position(orders=[fake_order_1, fake_order_2])
89 | assert p._get_id() == "0dbf110bd4db94de539295c65867705d"
--------------------------------------------------------------------------------
/app/api/api_v1/endpoints/strategy.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import List
3 |
4 | from fastapi import APIRouter
5 | from pydantic import BaseModel
6 | from starlette.responses import Response
7 |
8 | from components import Strategy
9 | from components.backtest.backtest import BacktestResult
10 | from components.ohlc import DataAdapter
11 | from loguru import logger
12 |
13 | router = APIRouter()
14 |
15 |
16 | @router.get("/strategy")
17 | async def list_all_strategies():
18 | """List all strategies"""
19 | strategies = Strategy.objects.all()
20 | return [s.as_model() for s in strategies]
21 |
22 |
23 | @router.get("/")
24 | async def get_strategy(name: str):
25 | """Get a strategy by name"""
26 | try:
27 | strategy = Strategy.objects.get(name=name)
28 | return strategy.as_model()
29 | except ValueError:
30 | return Response(status_code=404)
31 |
32 |
33 | class RunStrategyResponse(BaseModel):
34 | backtest: BacktestResult
35 | plots: List[dict]
36 |
37 |
38 | class RunStrategyRequest(BaseModel):
39 | strategy: str
40 | parameters: dict
41 | adapter: str
42 | adapter_kwargs: dict
43 |
44 | class Config:
45 | schema_extra = {
46 | "example": {
47 | "strategy": "SMACrossOver",
48 | "parameters": {},
49 | "adapter": "CSVAdapter",
50 | "adapter_kwargs": {
51 | "path": "tests/data/AAPL.csv"}
52 | }
53 | }
54 |
55 |
56 | @router.post("/", response_model=RunStrategyResponse)
57 | async def run_strategy(request: RunStrategyRequest):
58 | """Run a strategy with data"""
59 |
60 | # get arguments from request
61 | name = request.strategy
62 | data_adapter_name = request.adapter
63 | data_adapter_kwargs = request.adapter_kwargs
64 | parameters = request.parameters if request.parameters else {}
65 |
66 | # get strategy and data adapter
67 | try:
68 | da = DataAdapter.objects.get(name=data_adapter_name)
69 | except ValueError:
70 | return Response(status_code=404, content="Data adapter not found")
71 | try:
72 | strategy = Strategy.objects.get(name=name)
73 | except ValueError:
74 | return Response(status_code=404, content="Strategy not found")
75 |
76 | ohlc = da.get_data(**data_adapter_kwargs)
77 |
78 | backtest_result, plots = strategy.run(data=ohlc, parameters=parameters, plots=True)
79 | logger.info(f'Backtest result: {backtest_result.get_overview()}')
80 | logger.info(f'Plots: ({len(plots)})')
81 | return RunStrategyResponse(backtest=backtest_result, plots=[p.as_dict() for p in plots])
82 |
83 |
84 | class SignalsRequest(BaseModel):
85 | signal_type: str
86 | strategy: RunStrategyRequest
87 |
88 |
89 | @router.post("/signals")
90 | async def run_signals(request: SignalsRequest):
91 | """Run signals"""
92 | print(request.signal_type)
93 | pass
94 |
--------------------------------------------------------------------------------
/app/components/parameter.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | from pydantic import BaseModel
4 |
5 |
6 | class ParameterModel(BaseModel):
7 | name: str
8 | value: Union[bool, int, float, str, None]
9 |
10 |
11 | class BaseParameter:
12 | def __init__(self):
13 | self.name = None
14 | self._validate()
15 |
16 | def _validate(self):
17 | raise NotImplementedError
18 |
19 | def __str__(self):
20 | value = self.__getattribute__('value')
21 | kwargs = [f'{k}={v}' for k, v in self.__dict__.items() if k != 'name' and k != 'value']
22 | kwargs_str = '' if len(kwargs) == 0 else f' ({", ".join(kwargs)})'
23 | return f'{self.name} : {value}{kwargs_str}'
24 |
25 | def as_model(self):
26 | return ParameterModel(
27 | name=self.name,
28 | value=self.value
29 | )
30 |
31 |
32 | class Parameter:
33 | def __init__(self, value, *args, **kwargs):
34 | if isinstance(value, bool):
35 | self.value = BooleanParameter(value)
36 | elif isinstance(value, int):
37 | self.value = IntegerParameter(value, *args, **kwargs)
38 | elif isinstance(value, float):
39 | self.value = FloatParameter(value, *args, **kwargs)
40 | else:
41 | raise ValueError('Invalid parameter type')
42 |
43 | def __index__(self):
44 | return self.value.__index__()
45 |
46 |
47 |
48 |
49 | class IntegerParameter(BaseParameter):
50 | def __init__(self, value: int, min_value: int = 0, max_value: int = 9999):
51 | self.value = int(value)
52 | self.min_value = min_value
53 | self.max_value = max_value
54 | super().__init__()
55 |
56 | def __int__(self):
57 | return self.value
58 |
59 | def __index__(self):
60 | return int(self)
61 |
62 | def _validate(self):
63 | if self.value < self.min_value or self.value > self.max_value:
64 | raise ValueError(f'{self} must be between {self.min_value} and {self.max_value}')
65 |
66 |
67 | class FloatParameter(BaseParameter):
68 | def __init__(self, value: float, min_value: float = -9999, max_value: float = 9999):
69 | self.value = float(value)
70 | self.min_value = min_value
71 | self.max_value = max_value
72 | super().__init__()
73 |
74 | def __float__(self):
75 | return self.value
76 |
77 | def __int__(self):
78 | return int(self.value)
79 |
80 | def _validate(self):
81 | if self.value < self.min_value or self.value > self.max_value:
82 | raise ValueError(f'{self} must be between {self.min_value} and {self.max_value}')
83 | # ensure value is a float
84 | self.value = float(self.value)
85 |
86 | def __index__(self):
87 | return float(self)
88 |
89 |
90 | class BooleanParameter(BaseParameter):
91 | def __init__(self, value: bool):
92 | self.value = bool(value)
93 | super().__init__()
94 |
95 | def __bool__(self):
96 | return self.value
97 |
98 | def __index__(self):
99 | return bool(self)
100 |
101 | def _validate(self):
102 | self.value = bool(self.value)
103 |
--------------------------------------------------------------------------------
/app/components/strategy/series.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Union
3 |
4 | import pandas as pd
5 |
6 |
7 | # eventually create a base series and have different types of series
8 | class Series:
9 | def __init__(self, data: Union[list, pd.Series]):
10 | self._loop_index = 0
11 | self._data = data
12 | self._index_func = {
13 | list: lambda: data[self._loop_index],
14 | pd.Series: lambda: data.iat[self._loop_index]
15 | }[type(self._data)]
16 | self._shift_func = {
17 | list: lambda n: data[self._loop_index - n],
18 | pd.Series: lambda n: data.iat[self._loop_index - n]
19 | }[type(self._data)]
20 |
21 | def advance_index(self):
22 | self._loop_index += 1
23 |
24 | def as_list(self):
25 | if isinstance(self._data, list):
26 | return self._data
27 | if isinstance(self._data, pd.Series):
28 | df = self._data.copy()
29 | # replace NaN with previous value
30 | df.fillna(method='backfill', inplace=True)
31 | return df.tolist()
32 |
33 | def __repr__(self):
34 | return str(float(self))
35 |
36 | def __len__(self):
37 | return len(self._data)
38 |
39 | def __getitem__(self, item):
40 | if isinstance(self._data, list):
41 | return self._data[item]
42 | if isinstance(self._data, pd.Series):
43 | return self._data.iat[item]
44 |
45 | def __float__(self):
46 | return float(self._index_func())
47 |
48 | def shift(self, n=1):
49 | return float(self._shift_func(n))
50 |
51 | def __int__(self):
52 | return int(float(self))
53 |
54 | def __add__(self, other):
55 | return float(self) + other
56 |
57 | def __sub__(self, other):
58 | return float(self) - other
59 |
60 | def __mul__(self, other):
61 | return float(self) * other
62 |
63 | def __truediv__(self, other):
64 | return float(self) / other
65 |
66 | def __floordiv__(self, other):
67 | return float(self) // other
68 |
69 | def __mod__(self, other):
70 | return float(self) % other
71 |
72 | def __pow__(self, other):
73 | return float(self) ** other
74 |
75 | def __lt__(self, other):
76 | return float(self) < float(other)
77 |
78 | def __le__(self, other):
79 | return float(self) <= other
80 |
81 | def __eq__(self, other):
82 | return float(self) == other
83 |
84 | def __ne__(self, other):
85 | return float(self) != other
86 |
87 | def __gt__(self, other):
88 | return float(self) > float(other)
89 |
90 | def __ge__(self, other):
91 | return float(self) >= other
92 |
93 | def __and__(self, other):
94 | return float(self) and other
95 |
96 | def __or__(self, other):
97 | return float(self) or other
98 |
99 | def __neg__(self):
100 | return -float(self)
101 |
102 | def __pos__(self):
103 | return +float(self)
104 |
105 | def __abs__(self):
106 | return abs(float(self))
107 |
108 | def __invert__(self):
109 | return ~float(self)
110 |
111 | def __round__(self, n=None):
112 | return round(float(self), n)
113 |
114 | def __trunc__(self):
115 | return math.trunc(float(self))
116 |
117 | def __floor__(self):
118 | return math.floor(float(self))
119 |
120 | def __ceil__(self):
121 | return math.ceil(float(self))
122 |
123 | def __index__(self):
124 | return float(self)
125 |
126 | def __radd__(self, other):
127 | return other + float(self)
128 |
129 | def __rsub__(self, other):
130 | return other - float(self)
131 |
132 | def __rmul__(self, other):
133 | return other * float(self)
134 |
135 | def __rtruediv__(self, other):
136 | return other / float(self)
137 |
--------------------------------------------------------------------------------
/app/components/backtest/backtest.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures
2 | import queue
3 | from typing import Optional, List, Union
4 |
5 | from loguru import logger
6 | from pydantic import BaseModel
7 |
8 | from components.backtest.utils import remove_overlapping_positions
9 | from components.orders.order import Order
10 | from components.positions.positions import Position
11 |
12 |
13 | def get_effect(position: Position, order: Order):
14 | """Get the effect of an order on a position."""
15 | if position.get_size() > 0:
16 | if order.side == 'buy':
17 | return 'increase'
18 | elif order.side == 'sell':
19 | return 'decrease'
20 | elif position.get_size() < 0:
21 | if order.side == 'buy':
22 | return 'decrease'
23 | elif order.side == 'sell':
24 | return 'increase'
25 | else:
26 | return 'increase'
27 |
28 |
29 | def worker(q, data):
30 | while not q.empty():
31 | try:
32 | position = q.get_nowait()
33 | position.test(data)
34 | except queue.Empty:
35 | break
36 | finally:
37 | q.task_done()
38 |
39 |
40 | class BacktestResult(BaseModel):
41 | pnl: float
42 | wl_ratio: float
43 | sharpe_ratio: Optional[float]
44 | max_drawdown: Optional[float]
45 | max_drawdown_duration: Optional[int]
46 | trades: int
47 | winning_trades: int
48 | losing_trades: int
49 | positions: List[Position]
50 | orders: List[Order]
51 |
52 | def get_overview(self):
53 | """Get a dict with the most important backtest results."""
54 | return {
55 | 'pnl': self.pnl,
56 | 'wl_ratio': self.wl_ratio,
57 | 'trades': self.trades,
58 | 'winning_trades': self.winning_trades,
59 | 'losing_trades': self.losing_trades,
60 | }
61 |
62 |
63 | class Backtest:
64 | def __init__(self, data, strategy):
65 | self.data = data
66 | self.strategy = strategy
67 | self.result: Union[BacktestResult, None] = None
68 |
69 | def _get_orders_with_filled_timestamp(self, orders):
70 | return [o for o in orders if o.filled_timestamp is not None]
71 |
72 | def _sort_orders(self, orders: List[Order]):
73 | return sorted(orders, key=lambda x: x.timestamp)
74 |
75 | def test(self):
76 | logger.debug(f'Starting backtest for strategy {self.strategy.name}')
77 |
78 | positions = self.strategy.positions.all()
79 | orders = self.strategy.orders.all()
80 |
81 | # create a bounded queue to hold the positions
82 | position_queue = queue.Queue()
83 | for position in positions:
84 | position_queue.put(position)
85 |
86 | # use concurrent futures to test orders in parallel
87 | logger.debug(f'Testing {len(positions)} positions in parallel...')
88 | with concurrent.futures.ThreadPoolExecutor() as executor:
89 | for _ in range(len(positions)):
90 | executor.submit(worker, position_queue, self.data)
91 |
92 | # wait for all positions to be tested
93 | position_queue.join()
94 |
95 | logger.debug(f'Finished testing {len(positions)} positions in parallel.')
96 |
97 | # after all positions have been tested, we can check for overlapping positions
98 | positions = remove_overlapping_positions(positions, max_overlap=0)
99 |
100 | all_position_orders = []
101 | for p in positions:
102 | all_position_orders += p.orders
103 |
104 | # calculate win/loss ratio
105 | losing_trades = len([p for p in positions if p.pnl < 0])
106 | winning_trades = len([p for p in positions if p.pnl > 0])
107 | if losing_trades == 0:
108 | wl_ratio = 1
109 | elif winning_trades == 0:
110 | wl_ratio = 0
111 | else:
112 | wl_ratio = round(winning_trades / (winning_trades + losing_trades), 2)
113 |
114 | # filter and sort orders
115 | result_orders = self._get_orders_with_filled_timestamp(orders + all_position_orders)
116 | # sort ascending by timestamp
117 | result_orders = self._sort_orders(result_orders)
118 |
119 | # create backtest result
120 | self.result = BacktestResult(
121 | pnl=sum([position.pnl for position in positions]),
122 | wl_ratio=wl_ratio,
123 | trades=len(positions),
124 | winning_trades=winning_trades,
125 | losing_trades=losing_trades,
126 | positions=positions,
127 | orders=result_orders,
128 | )
129 |
--------------------------------------------------------------------------------
/app/components/orders/order.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import hashlib
3 | from typing import Optional, Union
4 |
5 | from loguru import logger
6 | from pydantic import BaseModel, ValidationError, validator
7 |
8 | from components.orders.enums import TimeInForce, OrderType, OrderSide
9 |
10 | """
11 | Similar to Alpaca-py
12 | """
13 |
14 |
15 | class OrderValidationError(Exception):
16 | pass
17 |
18 |
19 | class Order(BaseModel):
20 | id: Optional[str]
21 | timestamp: Union[int, None]
22 | symbol: str
23 | qty: int
24 | notional: Optional[float]
25 | side: OrderSide
26 | type: Optional[OrderType]
27 | time_in_force: TimeInForce = TimeInForce.GTC
28 | extended_hours: Optional[bool]
29 | client_order_id: Optional[str]
30 | filled_avg_price: Optional[float]
31 | filled_timestamp: Optional[int]
32 | did_not_fill: Optional[bool] = False
33 |
34 | @staticmethod
35 | def create_market_order(symbol: str, qty: float, side: OrderSide):
36 | return Order(
37 | symbol=symbol,
38 | qty=qty,
39 | side=side,
40 | type=OrderType.MARKET,
41 | )
42 |
43 | def _timestamp_to_datetime(self, timestamp: int):
44 | if timestamp is not None:
45 | return datetime.datetime.fromtimestamp(timestamp / 1000).strftime('%Y-%m-%d %H:%M:%S')
46 | else:
47 | return 'TBD'
48 |
49 | def __str__(self):
50 | order_type = OrderType.abbreviation(self.type).upper()
51 | side = self.side.upper()
52 | return f'{order_type} {side} [{abs(self.qty)}] {self.symbol} @ {self.filled_avg_price}\t' \
53 | f'({self._timestamp_to_datetime(self.timestamp)})'
54 |
55 | def __hash__(self):
56 | h = hashlib.sha256(f'{self.timestamp}{self.symbol}{self.qty}{self.side}{self.type}'.encode())
57 | return int(h.hexdigest(), 16)
58 |
59 | def get_id(self):
60 | return hashlib.md5(str(hash(self)).encode()).hexdigest()
61 |
62 | # will eventually make this a proper attribute
63 | @property
64 | def price(self):
65 | return self.filled_avg_price
66 |
67 | class Config:
68 | schema_extra = {
69 | "example": {
70 | "id": "b6b6b6b6-b6b6-b6b6-b6b6-b6b6b6b6b6b6",
71 | "symbol": "AAPL",
72 | "qty": 100,
73 | "side": "buy",
74 | "type": "market",
75 | }
76 | }
77 |
78 | def __init__(self, **data):
79 | try:
80 | super().__init__(**data)
81 | except ValidationError as e:
82 | logger.error(e)
83 | # logger.exception(e)
84 | raise e
85 |
86 | # if valid, set id
87 | self.id = self.get_id()
88 | self.type = self.type or OrderType.MARKET
89 | self.filled_timestamp = self.timestamp
90 |
91 | # if side is sell, qty must be negative
92 | if self.side == OrderSide.SELL:
93 | self.qty = -abs(self.qty)
94 | else:
95 | self.qty = abs(self.qty)
96 |
97 | @validator('qty')
98 | def qty_must_be_int(cls, v):
99 | assert v != 0, 'qty cannot be 0'
100 | return v
101 |
102 |
103 | class LimitOrder(Order):
104 | limit_price: float
105 |
106 | @property
107 | def price(self):
108 | return self.limit_price
109 |
110 | class Config:
111 | schema_extra = {
112 | "example": {
113 | "id": "b6b6b6b6-b6b6-b6b6-b6b6-b6b6b6b6b6b6",
114 | "symbol": "AAPL",
115 | "qty": 100,
116 | "side": "buy",
117 | "type": "limit",
118 | "limit_price": 100.00,
119 | }
120 | }
121 |
122 | def __init__(self, **data):
123 | super().__init__(**data)
124 | self.type = OrderType.LIMIT
125 |
126 | def __str__(self):
127 | order_type = OrderType.abbreviation(self.type).upper()
128 | side = self.side.upper()
129 | return f'{order_type} {side} [{self.qty}] {self.symbol} @ {self.limit_price}\t({self._timestamp_to_datetime(self.timestamp)})'
130 |
131 |
132 | class StopOrder(Order):
133 | stop_price: Optional[float]
134 |
135 | @property
136 | def price(self):
137 | return self.stop_price
138 |
139 | class Config:
140 | schema_extra = {
141 | "example": {
142 | "id": "b6b6b6b6-b6b6-b6b6-b6b6-b6b6b6b6b6b6",
143 | "symbol": "AAPL",
144 | "qty": 100,
145 | "side": "buy",
146 | "type": "stop",
147 | "stop_price": 100.00,
148 | }
149 | }
150 |
151 | def __init__(self, **data):
152 | super().__init__(**data)
153 | self.type = OrderType.STOP
154 |
155 | def __str__(self):
156 | order_type = OrderType.abbreviation(self.type).upper()
157 | side = self.side.upper()
158 | return f'{order_type} {side} [{self.qty}] {self.symbol} @ {self.stop_price}\t({self._timestamp_to_datetime(self.timestamp)})'
159 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | # Stratis
8 |
9 | [](https://github.com/robswc/stratis/blob/master/LICENSE)
10 | []()
11 | [](https://github.com/robswc/stratis)
12 | [](https://github.com/robswc/stratis/stargazers)
13 | [](https://twitter.com/robswc)
14 |
15 |
16 |
17 |
18 | Stratis is a python-based framework for developing and testing strategies, inspired by the simplicity
19 | of tradingview's [Pinescript](https://www.tradingview.com/pine-script-docs/en/v5/Introduction.html). Currently,
20 | stratis is in the early stages of development, and is not yet ready for production use. However, it is encouraged
21 | to try it out and provide feedback via [GitHub issues](https://github.com/robswc/stratis/issues/new).
22 |
23 | Stratis is a part of [Shenandoah Research's](https://shenandoah.capital/) Open source Trading Software Initiative and developed by Polyad Decision Sciences software engineers!
24 |
25 |
26 |
27 |
28 |
29 |
30 | #### _Please Note Stratis is under active development and is not yet ready for production use._
31 |
32 |
33 | ## Basic Example
34 |
35 | The following code demonstrates how to create a strategy that prints the timestamp and close price of the
36 | OHLC data every hour. The `on_step` decorator is used to run the function on every step of the OHLC data. You can find
37 | more info about how to create strategies [here](https://github.com/robswc/stratis/wiki/Strategies).
38 |
39 | Using the `Strategy`
40 | class, along with the `on_step` decorator, you can create strategies that are as simple or as complex as you want, with
41 | the full power of python at your disposal.
42 |
43 | ```python
44 | class OHLCDemo(Strategy):
45 |
46 | @on_step # on_step decorated, runs every "step" of the OHLC data
47 | def print_ohlc(self):
48 |
49 | # shorthands for the OHLC data
50 | timestamp = self.data.timestamp
51 | close = self.data.close
52 |
53 | # if the timestamp is a multiple of 3600000 (1 hour)
54 | if timestamp % 3600000 == 0:
55 | # create a datetime object from the timestamp
56 | dt = datetime.datetime.fromtimestamp(timestamp / 1000)
57 | if dt.hour == 10:
58 | print(f'{dt}: {close}')
59 | ```
60 |
61 | ```python
62 | data = CSVAdapter('data/AAPL.csv')
63 | strategy = OHLCDemo().run(data)
64 | ```
65 |
66 |
67 | ## Table of Contents
68 |
69 | - [Installation](#Installation)
70 | - [Docker](#Docker)
71 | - [Python and NPM](#Python-and-NPM)
72 |
73 | [//]: # (- [Features](#features))
74 |
75 | ## Installation
76 |
77 | It is heavily recommended to use [Docker](https://www.docker.com/resources/what-container/) to run stratis. This is because stratis requires a number of dependencies that
78 | can be difficult to install without Docker.
79 |
80 |
81 | ### Docker
82 | To install Docker, follow the instructions [here](https://docs.docker.com/get-docker/).
83 |
84 | Once Docker is installed, you can run stratis by running the following commands:
85 |
86 | ```bash
87 | # Clone the repository
88 | git clone https://github.com/robswc/stratis
89 |
90 | # Change directory to the repository
91 | cd stratis
92 |
93 | # Run the docker-compose file
94 | docker-compose up -d # -d runs the containers in the background
95 | ```
96 |
97 | The Stratis UI interface should now be accessible via:
98 |
99 | [http://localhost:3000](http://localhost:3000)
100 |
101 | And the Stratis backend (core) should be accessible via:
102 |
103 | [http://localhost:8000](http://localhost:8000)
104 |
105 | ### Python and NPM
106 |
107 | For more advanced usage, you can run app with python directly, as it is a FastAPI app under the hood.
108 | Please note, this may or may not work for Windows and MacOS, as I have only tested it on Linux.
109 |
110 | I would recommend using a [virtual environment](https://docs.python.org/3/library/venv.html) for this.
111 | You will also have to [install the requirements](https://pip.pypa.io/en/latest/user_guide/#requirements-files).
112 | The following commands will start the backend of stratis.
113 |
114 | ```bash
115 | cd app # change directory to the app folder
116 | python python -m uvicorn main:app --reload # reloads the app on file changes (useful for development)
117 | ```
118 |
119 | The frontend of stratis is a NextJS app. The repository for the frontend can be found [here](https://github.com/robswc/stratis-ui).
--------------------------------------------------------------------------------
/app/components/strategy/strategy.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from typing import List, Union
3 |
4 | import pandas as pd
5 | from loguru import logger
6 | from pydantic import BaseModel
7 |
8 | from components.backtest.backtest import Backtest
9 | from components.manager.manager import ComponentManager
10 | from components.ohlc import OHLC
11 | from components.orders.order_manager import OrderManager
12 | from components.positions.position_manager import PositionManager
13 | from components.parameter import BaseParameter, Parameter, ParameterModel
14 | from components.strategy.decorators import extract_decorators
15 |
16 |
17 | class PlotConfig(BaseModel):
18 | color: str = 'blue'
19 | type: str = 'line'
20 | lineWidth: int = 1
21 |
22 |
23 | class Plot:
24 | def __init__(self, series: 'Series', **kwargs):
25 | self.data = series.as_list()
26 | self.name = kwargs.get('name', None)
27 | self.config = PlotConfig(**kwargs)
28 |
29 | def as_dict(self):
30 | return {
31 | 'name': self.name,
32 | 'data': self.data,
33 | 'config': self.config.dict()
34 | }
35 |
36 |
37 | class StrategyManager(ComponentManager):
38 | _components = []
39 |
40 |
41 | class StrategyModel(BaseModel):
42 | name: str
43 | parameters: List[ParameterModel]
44 |
45 |
46 | class BaseStrategy:
47 | objects = StrategyManager
48 |
49 | @classmethod
50 | def register(cls):
51 | cls.objects.register(cls)
52 |
53 | def __init__(self, data: Union[OHLC, None] = None):
54 | self.name = self.__class__.__name__
55 |
56 | # handle data
57 | if data is None:
58 | data = OHLC()
59 | self.data = data
60 | self._loop_index = 0
61 |
62 | # create a shortcut to the symbol
63 | print(self.data.symbol)
64 | self.symbol = data.symbol
65 |
66 | # handle parameters
67 | self.parameters: List[BaseParameter] = []
68 | self._set_parameters()
69 |
70 | self.register()
71 |
72 | # strategy decorators
73 | self._step_methods = []
74 | self._before_methods = []
75 | self._after_methods = []
76 |
77 | befores, steps, afters = extract_decorators(self)
78 | self._before_methods = befores
79 | self._step_methods = steps
80 | self._after_methods = afters
81 |
82 | # each strategy gets a new order and position manager
83 | self.orders = OrderManager(self) # eventually all orders will be converted to positions
84 | self.positions = PositionManager(self)
85 |
86 | # each strategy gets plots
87 | self.plots = []
88 |
89 | def export_plots(self, plots: List[Plot]):
90 | self.plots = plots
91 |
92 | def as_model(self) -> StrategyModel:
93 | return StrategyModel(
94 | name=self.name,
95 | parameters=[p.as_model() for p in self.parameters],
96 | )
97 |
98 | def _get_all_parameters(self):
99 | parameters = []
100 | for attr in dir(self):
101 | if isinstance(self.__getattribute__(attr), BaseParameter):
102 | self.parameters.append(self.__getattribute__(attr))
103 | if isinstance(self.__getattribute__(attr), Parameter):
104 | p = self.__getattribute__(attr).value
105 | p.name = attr
106 | parameters.append(p)
107 | return parameters
108 |
109 | def _set_parameters(self):
110 | # find all parameters in the class
111 | for attr in dir(self):
112 | if isinstance(self.__getattribute__(attr), BaseParameter):
113 | self.parameters.append(self.__getattribute__(attr))
114 | if isinstance(self.__getattribute__(attr), Parameter):
115 | p = self.__getattribute__(attr).value
116 | p.name = attr
117 | self.parameters.append(p)
118 |
119 | def show_parameters(self):
120 | return '\n'.join([str(p) for p in self.parameters])
121 |
122 | def _setup_data(self, data: OHLC):
123 | self.data = data
124 |
125 | def _create_series(self):
126 | for attr in dir(self):
127 | if isinstance(self.__getattribute__(attr), list):
128 | if attr not in ['_before_methods', '_step_methods', '_after_methods', 'parameters']:
129 | sys.modules['components.strategy.series'].Series(self.__getattribute__(attr))
130 | if isinstance(self.__getattribute__(attr), pd.Series):
131 | sys.modules['components.strategy.series'].Series(self.__getattribute__(attr).to_list())
132 |
133 | def _get_all_series_data(self):
134 | series = []
135 | for attr in dir(self):
136 | if isinstance(self.__getattribute__(attr), sys.modules['components.strategy.series'].Series):
137 | series.append(self.__getattribute__(attr))
138 | return series
139 |
140 | def _get_all_plots(self):
141 | # will use in the future to get all plots
142 | plots = []
143 | for attr in dir(self):
144 | if isinstance(self.__getattribute__(attr), Plot):
145 | plots.append(self.__getattribute__(attr))
146 | return plots
147 |
148 | def run(self, data: OHLC, parameters: dict = None, **kwargs):
149 |
150 | # set parameters
151 | if parameters is not None:
152 | for p in self.parameters:
153 | if p.name in parameters:
154 | p.value = parameters[p.name]
155 |
156 | self.__init__(data=data)
157 | self._setup_data(data)
158 | self._create_series()
159 |
160 | # run before methods
161 | for method in self._before_methods:
162 | getattr(self, method)()
163 |
164 | # run step methods
165 |
166 | series = self._get_all_series_data()
167 |
168 | for i in range(len(self.data.dataframe)):
169 | for method in self._step_methods:
170 | getattr(self, method)()
171 |
172 | # advance the index of all series
173 | for s in series:
174 | s.advance_index()
175 |
176 | # advance the index of the data
177 | self.data.advance_index()
178 |
179 | # handle backtest
180 | b = Backtest(strategy=self, data=data)
181 |
182 | # runs the backtest
183 | if self.positions.positions or self.orders.orders:
184 | b.test()
185 |
186 | # run after methods
187 | for method in self._after_methods:
188 | getattr(self, method)()
189 |
190 | # get all plots
191 | plots = self.plots
192 |
193 | if kwargs.get('plots', False):
194 | logger.debug(f'Requested plots, found {len(plots)}')
195 | return b.result, plots
196 | return b.result
197 |
--------------------------------------------------------------------------------
/app/components/ohlc/ohlc.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import random
3 | from typing import Union
4 |
5 | import numpy as np
6 | import pandas as pd
7 | from loguru import logger
8 |
9 | from components.ohlc.symbol import Symbol
10 |
11 | EMPTY_DATA = pd.DataFrame(columns=['open', 'high', 'low', 'close', 'volume', 'timestamp'])
12 | EMPTY_DATA.set_index('timestamp', inplace=True)
13 |
14 |
15 | class FutureTimestampRequested(Exception):
16 | def __init__(self, dataframe, index):
17 | super().__init__(f'Future timestamp requested. You must ensure the OHLCs resolution is set for future '
18 | f'timestamp extrapolation.\n'
19 | f'Dataframe length: {len(dataframe)}, index: {index}\n')
20 |
21 |
22 | class OHLC:
23 | """
24 | OHLCV data class. This class is used to store OHLCV data.
25 | Wraps around a pandas dataframe.
26 | """
27 |
28 | def __init__(
29 | self,
30 | symbol: Symbol = None,
31 | dataframe: pd.DataFrame = None,
32 | resolution: Union[int, None] = None,
33 | ):
34 | self.symbol = symbol
35 |
36 | # if no dataframe is provided, use an empty dataframe
37 | if dataframe is None:
38 | dataframe = EMPTY_DATA
39 | self.dataframe = dataframe
40 |
41 | # if no resolution is provided, attempt to interpret it
42 | if resolution is None:
43 | try:
44 | self._interpret_resolution()
45 | except ValueError:
46 | logger.error(f'Unable to interpret resolution for {self.symbol}')
47 | self.resolution = 0
48 | else:
49 | self.resolution = resolution
50 | self._index = 0
51 |
52 | # if a dataframe is provided, validate it
53 | if dataframe is not None:
54 | self._validate()
55 |
56 | def _interpret_resolution(self):
57 | """Attempts to interpret the resolution of the OHLC data."""
58 | # generate a sample size that is 25% of the data
59 | sample_size = len(self.dataframe) // 4
60 |
61 | # generate a set of random indexes to sample
62 | indexes = random.sample(range(len(self.dataframe) - 1), sample_size)
63 |
64 | # take random sample pairs and calculate the difference between timestamps
65 | diffs = [self.dataframe.index[i + 1] - self.dataframe.index[i] for i in indexes]
66 |
67 | # get the most common difference, convert to minutes
68 | self.resolution = int(max(set(diffs), key=diffs.count) / 1000 / 60)
69 |
70 | def advance_index(self, n: int = 1):
71 | self._index += n
72 |
73 | def reset_index(self):
74 | self._index = 0
75 |
76 | def _get_ohlc(self, column: str, index: int = None):
77 | if index is None:
78 | index = self._index
79 | try:
80 | value = self.dataframe[column].iat[index]
81 | except IndexError:
82 | logger.error(f'Index out of range. Index: {index}, Length: {len(self.dataframe)}')
83 | value = None
84 | return value
85 |
86 | @property
87 | def open(self):
88 | return self._get_ohlc('open')
89 |
90 | @property
91 | def high(self):
92 | return self._get_ohlc('high')
93 |
94 | @property
95 | def low(self):
96 | return self._get_ohlc('low')
97 |
98 | @property
99 | def close(self):
100 | return self._get_ohlc('close')
101 |
102 | @property
103 | def volume(self):
104 | return self.dataframe['volume']
105 |
106 | @property
107 | def timestamp(self):
108 | return self.dataframe.index[self._index]
109 |
110 | def get_timestamp(self, offset: int = 0):
111 | try:
112 | return self.dataframe.index[self._index + offset]
113 | except IndexError:
114 | if self._index + offset == len(self.dataframe):
115 | raise FutureTimestampRequested(self.dataframe, self._index + offset)
116 |
117 | def all(self, column: str):
118 | try:
119 | return self.dataframe[column]
120 | except KeyError:
121 | return []
122 |
123 | def __str__(self):
124 | return f'OHLC: {self.symbol}'
125 |
126 | def __getattr__(self, item):
127 | # check if the attribute is part of the class
128 | if item in self.__dict__:
129 | return self.__dict__[item]
130 | # else, forward the attribute to the dataframe
131 | else:
132 | return getattr(self.dataframe, item)
133 |
134 | def _validate(self):
135 | # ensure data has the correct columns
136 | if not {'open', 'high', 'low', 'close', 'volume'}.issubset(self.dataframe.columns):
137 | raise ValueError(f'Invalid data. Missing columns. Expected: open, high, low, close, volume.')
138 |
139 | # ensure the index of data is 'timestamp'
140 | if self.dataframe.index.name != 'timestamp':
141 | raise ValueError(f'Invalid data. Index must be named "timestamp", not "{self.dataframe.index.name}".')
142 |
143 | @staticmethod
144 | def from_csv(path: str, symbol: str):
145 | """
146 | Loads data from a csv file.
147 | :param symbol: symbol for the data
148 | :param path: Path to csv file.
149 | :return: self
150 | """
151 |
152 | # strategy a symbol from the symbol string
153 | symbol = Symbol(symbol)
154 |
155 | # load the data from the csv file
156 | dataframe = pd.read_csv(path)
157 | dataframe.set_index('timestamp', inplace=True)
158 |
159 | # create and validate OHLC object
160 | ohlc = OHLC(symbol=symbol, dataframe=dataframe)
161 | ohlc._validate()
162 |
163 | return ohlc
164 |
165 | def trim(self, start: int, end: int):
166 | """
167 | Trims the OHLC data to the specified range.
168 | :param start: Start index.
169 | :param end: End index.
170 | :return: self
171 | """
172 |
173 | start_dt = datetime.datetime.fromtimestamp(start / 1000)
174 | end_dt = datetime.datetime.fromtimestamp(end / 1000)
175 |
176 | df = self.dataframe.copy()
177 | logger.debug(f'Trimming data from {start_dt} to {end_dt}')
178 | logger.debug(f'Original data length: {len(df)}')
179 | df = df.loc[start:end]
180 | self.dataframe = df
181 | self._validate()
182 | logger.debug(f'Trimmed data length: {len(df)}')
183 | return self
184 |
185 | def __len__(self):
186 | return len(self.dataframe)
187 |
188 | def to_dict(self):
189 | df = self.dataframe.copy()
190 | df['time'] = df.index
191 | return {
192 | 'symbol': self.symbol,
193 | 'resolution': self.resolution,
194 | 'data': df.to_dict(orient='records')
195 | }
196 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | app/tests/test_api_adapter.py
2 |
3 | .wiki
4 |
5 | # Created by https://www.toptal.com/developers/gitignore/api/python,pycharm
6 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,pycharm
7 |
8 | ### PyCharm ###
9 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
10 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
11 |
12 | # User-specific stuff
13 | .idea/**/workspace.xml
14 | .idea/**/tasks.xml
15 | .idea/**/usage.statistics.xml
16 | .idea/**/dictionaries
17 | .idea/**/shelf
18 |
19 | # AWS User-specific
20 | .idea/**/aws.xml
21 |
22 | # Generated files
23 | .idea/**/contentModel.xml
24 |
25 | # Sensitive or high-churn files
26 | .idea/**/dataSources/
27 | .idea/**/dataSources.ids
28 | .idea/**/dataSources.local.xml
29 | .idea/**/sqlDataSources.xml
30 | .idea/**/dynamic.xml
31 | .idea/**/uiDesigner.xml
32 | .idea/**/dbnavigator.xml
33 |
34 | # Gradle
35 | .idea/**/gradle.xml
36 | .idea/**/libraries
37 |
38 | # Gradle and Maven with auto-import
39 | # When using Gradle or Maven with auto-import, you should exclude module files,
40 | # since they will be recreated, and may cause churn. Uncomment if using
41 | # auto-import.
42 | # .idea/artifacts
43 | # .idea/compiler.xml
44 | # .idea/jarRepositories.xml
45 | # .idea/modules.xml
46 | # .idea/*.iml
47 | # .idea/modules
48 | # *.iml
49 | # *.ipr
50 |
51 | # CMake
52 | cmake-build-*/
53 |
54 | # Mongo Explorer plugin
55 | .idea/**/mongoSettings.xml
56 |
57 | # File-based project format
58 | *.iws
59 |
60 | # IntelliJ
61 | out/
62 |
63 | # mpeltonen/sbt-idea plugin
64 | .idea_modules/
65 |
66 | # JIRA plugin
67 | atlassian-ide-plugin.xml
68 |
69 | # Cursive Clojure plugin
70 | .idea/replstate.xml
71 |
72 | # SonarLint plugin
73 | .idea/sonarlint/
74 |
75 | # Crashlytics plugin (for Android Studio and IntelliJ)
76 | com_crashlytics_export_strings.xml
77 | crashlytics.properties
78 | crashlytics-build.properties
79 | fabric.properties
80 |
81 | # Editor-based Rest Client
82 | .idea/httpRequests
83 |
84 | # Android studio 3.1+ serialized cache file
85 | .idea/caches/build_file_checksums.ser
86 |
87 | ### PyCharm Patch ###
88 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
89 |
90 | # *.iml
91 | # modules.xml
92 | # .idea/misc.xml
93 | # *.ipr
94 |
95 | # Sonarlint plugin
96 | # https://plugins.jetbrains.com/plugin/7973-sonarlint
97 | .idea/**/sonarlint/
98 |
99 | # SonarQube Plugin
100 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
101 | .idea/**/sonarIssues.xml
102 |
103 | # Markdown Navigator plugin
104 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
105 | .idea/**/markdown-navigator.xml
106 | .idea/**/markdown-navigator-enh.xml
107 | .idea/**/markdown-navigator/
108 |
109 | # Cache file creation bug
110 | # See https://youtrack.jetbrains.com/issue/JBR-2257
111 | .idea/$CACHE_FILE$
112 |
113 | # CodeStream plugin
114 | # https://plugins.jetbrains.com/plugin/12206-codestream
115 | .idea/codestream.xml
116 |
117 | # Azure Toolkit for IntelliJ plugin
118 | # https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij
119 | .idea/**/azureSettings.xml
120 |
121 | ### Python ###
122 | # Byte-compiled / optimized / DLL files
123 | __pycache__/
124 | *.py[cod]
125 | *$py.class
126 |
127 | # C extensions
128 | *.so
129 |
130 | # Distribution / packaging
131 | .Python
132 | build/
133 | develop-eggs/
134 | dist/
135 | downloads/
136 | eggs/
137 | .eggs/
138 | lib/
139 | lib64/
140 | parts/
141 | sdist/
142 | var/
143 | wheels/
144 | share/python-wheels/
145 | *.egg-info/
146 | .installed.cfg
147 | *.egg
148 | MANIFEST
149 |
150 | # PyInstaller
151 | # Usually these files are written by a python script from a template
152 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
153 | *.manifest
154 | *.spec
155 |
156 | # Installer logs
157 | pip-log.txt
158 | pip-delete-this-directory.txt
159 |
160 | # Unit test / coverage reports
161 | htmlcov/
162 | .tox/
163 | .nox/
164 | .coverage
165 | .coverage.*
166 | .cache
167 | nosetests.xml
168 | coverage.xml
169 | *.cover
170 | *.py,cover
171 | .hypothesis/
172 | .pytest_cache/
173 | cover/
174 |
175 | # Translations
176 | *.mo
177 | *.pot
178 |
179 | # Django stuff:
180 | *.log
181 | local_settings.py
182 | db.sqlite3
183 | db.sqlite3-journal
184 |
185 | # Flask stuff:
186 | instance/
187 | .webassets-cache
188 |
189 | # Scrapy stuff:
190 | .scrapy
191 |
192 | # Sphinx documentation
193 | docs/_build/
194 |
195 | # PyBuilder
196 | .pybuilder/
197 | target/
198 |
199 | # Jupyter Notebook
200 | .ipynb_checkpoints
201 |
202 | # IPython
203 | profile_default/
204 | ipython_config.py
205 |
206 | # pyenv
207 | # For a library or package, you might want to ignore these files since the code is
208 | # intended to run in multiple environments; otherwise, check them in:
209 | # .python-version
210 |
211 | # pipenv
212 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
213 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
214 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
215 | # install all needed dependencies.
216 | #Pipfile.lock
217 |
218 | # poetry
219 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
220 | # This is especially recommended for binary packages to ensure reproducibility, and is more
221 | # commonly ignored for libraries.
222 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
223 | #poetry.lock
224 |
225 | # pdm
226 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
227 | #pdm.lock
228 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
229 | # in version control.
230 | # https://pdm.fming.dev/#use-with-ide
231 | .pdm.toml
232 |
233 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
234 | __pypackages__/
235 |
236 | # Celery stuff
237 | celerybeat-schedule
238 | celerybeat.pid
239 |
240 | # SageMath parsed files
241 | *.sage.py
242 |
243 | # Environments
244 | .env
245 | .venv
246 | env/
247 | venv/
248 | ENV/
249 | env.bak/
250 | venv.bak/
251 |
252 | # Spyder project settings
253 | .spyderproject
254 | .spyproject
255 |
256 | # Rope project settings
257 | .ropeproject
258 |
259 | # mkdocs documentation
260 | /site
261 |
262 | # mypy
263 | .mypy_cache/
264 | .dmypy.json
265 | dmypy.json
266 |
267 | # Pyre type checker
268 | .pyre/
269 |
270 | # pytype static type analyzer
271 | .pytype/
272 |
273 | # Cython debug symbols
274 | cython_debug/
275 |
276 | # PyCharm
277 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
278 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
279 | # and can be added to the global gitignore or merged into this file. For a more nuclear
280 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
281 | #.idea/
282 |
283 | ### Python Patch ###
284 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
285 | poetry.toml
286 |
287 | # ruff
288 | .ruff_cache/
289 |
290 | # LSP config files
291 | pyrightconfig.json
292 |
293 | # End of https://www.toptal.com/developers/gitignore/api/python,pycharm
294 |
--------------------------------------------------------------------------------
/app/tests/test_backtest.py:
--------------------------------------------------------------------------------
1 | from components.backtest.backtest import Backtest
2 | from components.ohlc import CSVAdapter
3 | from components.orders import Order, LimitOrder, StopOrder
4 | from components.orders.enums import OrderSide
5 | from components.positions import Position
6 | from storage.strategies.examples.sma_cross_over import SMACrossOver
7 |
8 | OHLC = CSVAdapter().get_data(path='tests/data/AAPL.csv', symbol='AAPL')
9 | STRATEGY = SMACrossOver(data=OHLC)
10 |
11 |
12 | class TestBacktest:
13 | def test_backtest_orders(self):
14 | # add orders
15 | STRATEGY.orders.market_order(side='buy', quantity=1)
16 |
17 | backtest = Backtest(strategy=STRATEGY, data=OHLC)
18 | backtest.test()
19 |
20 | assert backtest.result is not None
21 |
22 | def test_backtest_positions(self):
23 | QTY = 1
24 |
25 | # add positions
26 | STRATEGY.positions.open(order_type='market', side='buy', quantity=QTY)
27 | entry_price, opened_timestamp = STRATEGY.data.close, (STRATEGY.data.timestamp + 300000) # add 5 minutes
28 | # timestamp
29 | STRATEGY.data.advance_index(100)
30 | STRATEGY.positions.close()
31 | exit_price, closed_timestamp = STRATEGY.data.close, STRATEGY.data.timestamp
32 |
33 | # create and run backtest
34 | backtest = Backtest(strategy=STRATEGY, data=OHLC)
35 | backtest.test()
36 |
37 | # grab the newly tested position
38 | p = STRATEGY.positions.all()[0]
39 |
40 | # calculate expected profit
41 | expected_profit = (exit_price - entry_price) * QTY
42 |
43 | # check position
44 | assert p.size == 0
45 | assert p.pnl == expected_profit
46 | assert p.unrealized_pnl is None
47 | assert p.average_entry_price == entry_price
48 | assert p.average_exit_price == exit_price
49 | assert p.opened_timestamp == opened_timestamp
50 | assert p.closed_timestamp == closed_timestamp
51 | assert backtest.result is not None
52 |
53 | def test_bracket_position(self):
54 | STRATEGY.data.reset_index()
55 | STRATEGY.data.advance_index(100)
56 | STRATEGY.orders.orders = []
57 | root_order = STRATEGY.orders.market_order(side='buy', quantity=1)
58 | STRATEGY.orders.limit_order(side='sell', quantity=1, price=root_order.price + 2)
59 | STRATEGY.orders.stop_loss_order(side='sell', quantity=1, price=root_order.price - 10)
60 |
61 | # create position
62 | p = Position(orders=STRATEGY.orders.all())
63 | p.test(ohlc=OHLC)
64 |
65 | # check position
66 | assert p.size == 0
67 | assert p.pnl == 1.9999999999999716
68 |
69 | def test_stop_loss(self):
70 | strategy = SMACrossOver(data=OHLC)
71 | strategy.orders.orders = []
72 | strategy.data.reset_index()
73 | strategy.data.advance_index(100)
74 |
75 | # add orders
76 | root_order = strategy.orders.market_order(side='buy', quantity=1)
77 | strategy.orders.stop_loss_order(side='sell', quantity=1, price=root_order.price - 5)
78 | strategy.orders.limit_order(side='sell', quantity=1, price=root_order.price + 100)
79 |
80 | # create position
81 | p = Position(orders=strategy.orders.all())
82 | p.test(ohlc=OHLC)
83 |
84 | # check position
85 | assert p.size == 0
86 | assert p.pnl == -5
87 |
88 | def test_overview_long_orders(self):
89 | strategy = SMACrossOver(data=OHLC)
90 | strategy.orders.orders = []
91 | strategy.data.reset_index()
92 | strategy.data.advance_index(100)
93 |
94 | # create positions
95 | for i in range(3):
96 | strategy.data.advance_index(5)
97 | strategy.positions.open(order_type='market', side='buy', quantity=1)
98 | strategy.data.advance_index(5)
99 | strategy.positions.close()
100 |
101 | # create backtest
102 | backtest = Backtest(strategy=strategy, data=OHLC)
103 | backtest.test()
104 |
105 | # check overview
106 | assert backtest.result.pnl == -5.170000000000016
107 |
108 | def test_overview_short_orders(self):
109 | # now do the same with short orders
110 | strategy = SMACrossOver(data=OHLC)
111 | strategy.orders.orders = []
112 | strategy.data.reset_index()
113 | strategy.data.advance_index(100)
114 |
115 | # create positions
116 | for i in range(3):
117 | strategy.data.advance_index(5)
118 | strategy.positions.open(order_type='market', side='sell', quantity=1)
119 | strategy.data.advance_index(5)
120 | strategy.positions.close()
121 |
122 | # create backtest
123 | backtest = Backtest(strategy=strategy, data=OHLC)
124 | backtest.test()
125 |
126 | # check overview
127 | assert backtest.result.pnl == 5.170000000000016
128 |
129 | def test_complex_positions(self):
130 | b = Backtest(strategy=SMACrossOver(), data=OHLC)
131 |
132 | p1 = Position(
133 | orders=[
134 | Order(side='buy', symbol='AAPL', qty=1, order_type='market', filled_avg_price=100, timestamp=1),
135 | Order(side='sell', symbol='AAPL', qty=1, order_type='limit', filled_avg_price=150, timestamp=2),
136 | ]
137 | )
138 |
139 | p2 = Position(
140 | orders=[
141 | Order(side='sell', symbol='AAPL', qty=1, order_type='market', filled_avg_price=100, timestamp=10),
142 | Order(side='buy', symbol='AAPL', qty=1, order_type='stop', filled_avg_price=90, timestamp=20),
143 | ]
144 | )
145 |
146 | p3 = Position(
147 | orders=[
148 | Order(side='buy', symbol='AAPL', qty=1, order_type='market', filled_avg_price=100, timestamp=100),
149 | Order(side='sell', symbol='AAPL', qty=1, order_type='stop', filled_avg_price=110, timestamp=200),
150 | ]
151 | )
152 |
153 | b.strategy.positions.add(p1)
154 | b.strategy.positions.add(p2)
155 | b.strategy.positions.add(p3)
156 | b.test()
157 |
158 | assert b.result.pnl == 70
159 |
160 | def test_backtest_short_with_order_types(self):
161 |
162 | b = Backtest(strategy=SMACrossOver(), data=OHLC)
163 |
164 | side = OrderSide.SELL
165 | open_order = Order(
166 | type='market',
167 | side=side,
168 | qty=100,
169 | symbol='AAPL',
170 | filled_avg_price=257.33,
171 | timestamp=1653984000000,
172 | filled_timestamp=1653984000000,
173 | )
174 | take_profit = LimitOrder(
175 | side=OrderSide.inverse(side),
176 | qty=100,
177 | symbol='AAPL',
178 | limit_price=256,
179 | )
180 | stop_loss = StopOrder(
181 | side=OrderSide.inverse(side),
182 | qty=100,
183 | symbol='AAPL',
184 | stop_price=300
185 | )
186 |
187 | p = Position(orders=[open_order, take_profit, stop_loss])
188 | b.strategy.positions.add(p)
189 | b.test()
190 | tested_position = b.strategy.positions.all()[0]
191 | print(tested_position)
192 |
193 | assert tested_position.size == 0
194 | assert tested_position.closed_timestamp == 1653984600000
195 |
196 | def test_backtest_long_with_order_types(self):
197 | b = Backtest(strategy=SMACrossOver(), data=OHLC)
198 | side = OrderSide.BUY
199 | open_order = Order(
200 | type='market',
201 | side=side,
202 | qty=100,
203 | symbol='AAPL',
204 | filled_avg_price=257.33,
205 | timestamp=1653984000000,
206 | filled_timestamp=1653984000000,
207 | )
208 | take_profit = LimitOrder(
209 | side=OrderSide.inverse(side),
210 | qty=100,
211 | symbol='AAPL',
212 | limit_price=260,
213 | )
214 | stop_loss = StopOrder(
215 | side=OrderSide.inverse(side),
216 | qty=100,
217 | symbol='AAPL',
218 | stop_price=100
219 | )
220 |
221 | p = Position(orders=[open_order, take_profit, stop_loss])
222 | b.strategy.positions.add(p)
223 | b.test()
224 | tested_position = b.strategy.positions.all()[0]
225 | assert tested_position.size == 0
226 | assert tested_position.closed_timestamp == 1654183500000
227 |
--------------------------------------------------------------------------------
/app/components/positions/positions.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | from typing import Optional, List, Union
3 |
4 | import pandas as pd
5 | from loguru import logger
6 | from pydantic import BaseModel
7 |
8 | from components.orders.order import Order, StopOrder, LimitOrder
9 | from components.positions.enums import PositionEffect
10 | from components.positions.exceptions import PositionUnbalancedException, PositionClosedException
11 | from components.positions.utils import get_effect
12 |
13 |
14 | def binary_search(df, target, is_buy):
15 | low, high = 0, len(df) - 1
16 |
17 | while low <= high:
18 | mid = (low + high) // 2
19 | mid_value = df.iloc[mid]
20 |
21 | if (is_buy and mid_value.low <= target) or (not is_buy and mid_value.high >= target):
22 | if mid == 0 or (is_buy and df.iloc[mid - 1].low > target) or (
23 | not is_buy and df.iloc[mid - 1].high < target):
24 | return mid
25 | high = mid - 1
26 | else:
27 | low = mid + 1
28 |
29 | return None
30 |
31 |
32 | class Position(BaseModel):
33 | id: Optional[str] = None
34 | orders: List[Order] = []
35 | closed: bool = False
36 | cost_basis: Optional[float] = 0
37 | average_entry_price: Optional[float] = None
38 | average_exit_price: Optional[float] = None
39 | size: Optional[int] = 0
40 | largest_size: Optional[int] = 0
41 | side: Optional[str] = None
42 | unrealized_pnl: Optional[float] = None
43 | pnl: Optional[float] = 0
44 | opened_timestamp: Optional[int] = None
45 | closed_timestamp: Optional[int] = None
46 | is_tested: bool = False
47 |
48 | def __int__(self):
49 | super().__init__()
50 | self.id = self._get_id()
51 |
52 | def __str__(self):
53 | return f'Position: {""}\t[{self.side.upper()}]\t{self.size} {self.average_entry_price} -> ' \
54 | f'{self.average_exit_price} \t pnl:{round(self.pnl, 2)} (opn:' \
55 | f' {self.opened_timestamp}, cls: {self.closed_timestamp})'
56 |
57 | def get_size(self):
58 | """Get the size of all orders in the position."""
59 | return sum([o.qty for o in self.orders])
60 |
61 | def get_side(self):
62 | """Returns the side of the position, either buy or sell."""
63 | if self.side is None:
64 | return self.orders[0].side
65 | return self.side
66 |
67 | def _get_id(self):
68 | """Get the id of the position."""
69 | order_ids = [o.id for o in self.orders]
70 | return hashlib.md5(str(order_ids).encode()).hexdigest()
71 |
72 | def _get_root_side_orders(self):
73 | """Get the root side orders of the position."""
74 | return [o for o in self.orders if o.side == self.side]
75 |
76 | def _update_average_entry(self, order: Order):
77 | if self.average_exit_price:
78 | self.average_exit_price = ((self.average_exit_price + order.filled_avg_price) / 2)
79 | else:
80 | self.average_exit_price = order.filled_avg_price
81 |
82 | def _update_pnl(self, order: Order):
83 | difference = order.filled_avg_price - self.average_entry_price
84 | multiplier = -1 if self.get_side() == 'sell' else 1
85 | realized_pnl = (difference * abs(order.qty)) * multiplier
86 | self.pnl += realized_pnl
87 |
88 | def _update_largest_size(self):
89 | if abs(self.size) > self.largest_size:
90 | self.largest_size = abs(self.size)
91 |
92 | def _update_opened_timestamp(self, order: Order):
93 | if self.opened_timestamp is None:
94 | self.opened_timestamp = order.timestamp
95 |
96 | def _add_order_to_size(self, order: Order):
97 | self.size += order.qty
98 |
99 | def _fill_order(self, order: Union[Order, StopOrder, LimitOrder], ohlc: 'OHLC' = None):
100 | """Handles TBD orders. Sets the timestamp and fills the order if it was filled."""
101 | start_index = ohlc.index.get_loc(self.orders[0].timestamp)
102 | df = ohlc.dataframe.iloc[start_index + 1:]
103 |
104 | if order.type == 'stop':
105 | filled_order = self._process_stop_order(order, df)
106 | elif order.type == 'limit':
107 | filled_order = self._process_limit_order(order, df)
108 |
109 | if filled_order is None:
110 | self._handle_unfilled_order(order)
111 | else:
112 | order.timestamp = order.filled_timestamp
113 |
114 | def _process_stop_order(self, order, df: pd.DataFrame):
115 | condition = (df.low <= order.stop_price) if self.side == 'buy' else (df.high >= order.stop_price)
116 | filtered_df = df[condition]
117 |
118 | if not filtered_df.empty:
119 | order.filled_timestamp = filtered_df.index[0]
120 | order.filled_avg_price = order.stop_price
121 | return order
122 | return None
123 |
124 | def _process_limit_order(self, order, df: pd.DataFrame):
125 | condition = (df.high >= order.limit_price) if self.side == 'buy' else (df.low <= order.limit_price)
126 | filtered_df = df[condition]
127 |
128 | if not filtered_df.empty:
129 | order.filled_timestamp = filtered_df.index[0]
130 | order.filled_avg_price = order.limit_price
131 | return order
132 | return None
133 |
134 | def _handle_unfilled_order(self, order):
135 | logger.warning(f'Order {order.get_id()} was never filled')
136 | logger.warning(f'{self}')
137 | for order in self.orders:
138 | logger.warning(f'\t{order}')
139 | # set order did not fill to true
140 | order.did_not_fill = True
141 | return
142 |
143 | def handle_order(self, order: Union[Order, StopOrder, LimitOrder], ohlc: 'OHLC' = None):
144 |
145 | # if order isn't already filled, it cannot be handled yet
146 | if order.filled_timestamp is None:
147 | return
148 |
149 | # if the position is closed, we can't handle any more orders
150 | if self.closed:
151 | raise PositionClosedException('Position is already closed')
152 |
153 | # if the position is missing a side, set it to the side of the first order
154 | if self.side is None:
155 | self.side = order.side
156 |
157 | # gets the position effect, either add or reduce
158 | effect = get_effect(position=self, order=order)
159 |
160 | if effect == PositionEffect.ADD:
161 | # since the position is added, we need to calculate the cost basis
162 | self.cost_basis += order.filled_avg_price * order.qty
163 | # adjust the average entry price
164 | self.average_entry_price = self.cost_basis / (self.size + order.qty)
165 |
166 | if effect == PositionEffect.REDUCE:
167 | # check that the position will not change sides
168 | if self.side == 'buy' and self.size + order.qty < 0:
169 | raise PositionUnbalancedException('Position will change sides')
170 | if self.side == 'sell' and self.size + order.qty > 0:
171 | raise PositionUnbalancedException('Position will change sides')
172 |
173 | # since the position is reduced, we need to re-calculate the average entry price and update the PNL
174 | self._update_average_entry(order)
175 | self._update_pnl(order)
176 |
177 | # adjust the size, if the size is 0, the position is closed
178 | self._add_order_to_size(order)
179 | self._update_largest_size()
180 | self._update_opened_timestamp(order)
181 |
182 | # set closed, update closed timestamp if closed
183 | self.closed = self.size == 0
184 | if self.closed:
185 | self.closed_timestamp = order.timestamp
186 |
187 | # if order is missing filled timestamp, set it to the order timestamp
188 | if order.filled_timestamp is None:
189 | order.filled_timestamp = order.timestamp
190 |
191 | def test(self, ohlc: 'OHLC' = None):
192 | """Backtest the position."""
193 | # handle all orders with a filled timestamp, as these are already filled
194 | filled_orders = [o for o in self.orders if o.filled_timestamp is not None]
195 | for order in filled_orders:
196 | self.handle_order(order=order, ohlc=ohlc)
197 |
198 | # if there are still working orders, handle them
199 | working_orders = [o for o in self.orders if o.filled_timestamp is None]
200 | if len(working_orders) > 0:
201 |
202 | # handle all orders without a filled timestamp, as these are TBD
203 | for order in working_orders:
204 | self._fill_order(order=order, ohlc=ohlc)
205 |
206 | # determine which order was filled first, filter out any orders that were never filled
207 | filled_working_orders = [o for o in working_orders if o.filled_timestamp is not None]
208 | sorted_orders = sorted(filled_working_orders, key=lambda o: o.timestamp)
209 |
210 | # handle the first order
211 | first_order = sorted_orders[0]
212 | self.handle_order(order=first_order, ohlc=ohlc)
213 |
214 | # set the remaining order's filled_timestamp to None
215 | for order in sorted_orders[1:]:
216 | order.filled_timestamp = None
217 |
218 |
219 | class BracketPosition(Position):
220 | def __init__(self):
221 | super().__init__()
222 | self.stop_order: Union[StopOrder, None] = None
223 | self.limit_order: Union[LimitOrder, None] = None
224 |
--------------------------------------------------------------------------------