├── data ├── __init__.py ├── log │ └── .gitkeep ├── agents │ └── .gitkeep ├── reports │ └── .gitkeep ├── .DS_Store └── params.db ├── lib ├── __init__.py ├── data │ ├── __init__.py │ ├── features │ │ ├── __init__.py │ │ ├── transform.py │ │ └── indicators.py │ └── providers │ │ ├── dates │ │ ├── __init__.py │ │ └── ProviderDateFormat.py │ │ ├── __init__.py │ │ ├── StaticDataProvider.py │ │ ├── BaseDataProvider.py │ │ └── ExchangeDataProvider.py ├── util │ ├── __init__.py │ ├── logger.py │ └── benchmarks.py ├── env │ ├── render │ │ ├── __init__.py │ │ └── TradingChart.py │ ├── __init__.py │ ├── trade │ │ ├── __init__.py │ │ ├── BaseTradeStrategy.py │ │ ├── LiveTradeStrategy.py │ │ └── SimulatedTradeStrategy.py │ ├── reward │ │ ├── __init__.py │ │ ├── BaseRewardStrategy.py │ │ ├── IncrementalProfit.py │ │ └── WeightedUnrealizedProfit.py │ └── TradingEnv.py ├── __init__.pyc ├── cli │ ├── functions │ │ ├── __init__.py │ │ └── update_data.py │ ├── __init__.py │ └── RLTraderCLI.py └── RLTrader.py ├── test ├── __init__.py ├── data │ ├── __init__.py │ └── test_providers.py └── test_rl_trader.py ├── config └── config.ini.dist ├── requirements.tests.txt ├── requirements.txt ├── requirements.no-gpu.txt ├── visualization.gif ├── .travis.yml ├── .dockerignore ├── .gitignore ├── requirements.base.txt ├── docker ├── Dockerfile.gpu ├── Dockerfile.backend ├── Dockerfile.cpu └── Dockerfile.tests ├── .github └── FUNDING.yml ├── run-tests-with-docker ├── optimize.py ├── run-with-docker ├── dev-with-docker ├── cli.py ├── Vagrantfile ├── Experiments.ipynb ├── README.md └── LICENSE /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/log/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/agents/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/reports/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /lib/data/features/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /config/config.ini.dist: -------------------------------------------------------------------------------- 1 | [Defaults] 2 | mini-batches=11 -------------------------------------------------------------------------------- /requirements.tests.txt: -------------------------------------------------------------------------------- 1 | -r requirements.no-gpu.txt 2 | pytest -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | -r requirements.base.txt 2 | tensorflow-gpu -------------------------------------------------------------------------------- /requirements.no-gpu.txt: -------------------------------------------------------------------------------- 1 | -r requirements.base.txt 2 | tensorflow 3 | -------------------------------------------------------------------------------- /lib/env/render/__init__.py: -------------------------------------------------------------------------------- 1 | from lib.env.render.TradingChart import TradingChart 2 | -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notadamking/RLTrader/HEAD/data/.DS_Store -------------------------------------------------------------------------------- /data/params.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notadamking/RLTrader/HEAD/data/params.db -------------------------------------------------------------------------------- /lib/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notadamking/RLTrader/HEAD/lib/__init__.pyc -------------------------------------------------------------------------------- /visualization.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notadamking/RLTrader/HEAD/visualization.gif -------------------------------------------------------------------------------- /lib/cli/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from lib.cli.functions.update_data import download_data_async 2 | -------------------------------------------------------------------------------- 
/.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: required 2 | 3 | services: 4 | - docker 5 | 6 | script: bash run-tests-with-docker -------------------------------------------------------------------------------- /lib/cli/__init__.py: -------------------------------------------------------------------------------- 1 | from lib.cli.functions import * 2 | from lib.cli.RLTraderCLI import RLTraderCLI 3 | -------------------------------------------------------------------------------- /lib/data/providers/dates/__init__.py: -------------------------------------------------------------------------------- 1 | from lib.data.providers.dates.ProviderDateFormat import ProviderDateFormat 2 | -------------------------------------------------------------------------------- /lib/env/__init__.py: -------------------------------------------------------------------------------- 1 | from lib.env.TradingEnv import TradingEnv 2 | from lib.env.render.TradingChart import TradingChart 3 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .vagrant 3 | research 4 | .git 5 | tensorboard 6 | agents 7 | data/tensorboard 8 | data/agents 9 | data/postgres 10 | data/reports -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .ipynb_checkpoints 3 | .pytest_cache 4 | **/__pycache__ 5 | data/tensorboard/* 6 | data/agents/* 7 | data/postgres/* 8 | data/log/* 9 | data/reports/* 10 | *.pkl 11 | venv/* 12 | -------------------------------------------------------------------------------- /lib/env/trade/__init__.py: -------------------------------------------------------------------------------- 1 | from lib.env.trade.BaseTradeStrategy import BaseTradeStrategy 2 | from 
lib.env.trade.LiveTradeStrategy import LiveTradeStrategy 3 | from lib.env.trade.SimulatedTradeStrategy import SimulatedTradeStrategy 4 | -------------------------------------------------------------------------------- /requirements.base.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | sklearn 4 | matplotlib 5 | gym 6 | stable_baselines 7 | optuna 8 | ta 9 | statsmodels==0.10.0rc2 10 | empyrical 11 | ccxt 12 | psycopg2 13 | configparser 14 | quantstats>=0.0.17 15 | -------------------------------------------------------------------------------- /lib/env/reward/__init__.py: -------------------------------------------------------------------------------- 1 | from lib.env.reward.IncrementalProfit import IncrementalProfit 2 | from lib.env.reward.WeightedUnrealizedProfit import WeightedUnrealizedProfit 3 | from lib.env.reward.BaseRewardStrategy import BaseRewardStrategy 4 | -------------------------------------------------------------------------------- /docker/Dockerfile.gpu: -------------------------------------------------------------------------------- 1 | FROM python:3.6.8-jessie 2 | 3 | ADD ./requirements.base.txt /code/ 4 | ADD ./requirements.txt /code/ 5 | 6 | WORKDIR /code 7 | 8 | RUN apt-get update \ 9 | && apt-get install -y build-essential mpich libpq-dev \ 10 | && pip install -r requirements.txt -------------------------------------------------------------------------------- /docker/Dockerfile.backend: -------------------------------------------------------------------------------- 1 | FROM postgres:11-alpine 2 | 3 | ARG ID=1000 4 | ARG GI=1000 5 | 6 | ENV POSTGRES_USER=rl_trader 7 | ENV POSTGRES_PASSWORD=rl_trader 8 | ENV POSTGRES_DB='rl_trader' 9 | ENV PGDATA=/var/lib/postgresql/data/trader-data 10 | 11 | RUN adduser -D -u $ID rl_trader -------------------------------------------------------------------------------- /.github/FUNDING.yml: 
-------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: 'notadamking' 4 | patreon: 'notadamking' 5 | custom: ['https://www.blockchain.com/btc/address/1Lc47bhYvdyKGk1qN8oBHdYQTkbFLL3PFw', 'https://www.blockchain.com/eth/address/0x9907A0cF64Ec9Fbf6Ed8FD4971090DE88222a9aC'] 6 | -------------------------------------------------------------------------------- /lib/data/providers/__init__.py: -------------------------------------------------------------------------------- 1 | from lib.data.providers.dates.ProviderDateFormat import ProviderDateFormat 2 | 3 | from lib.data.providers.BaseDataProvider import BaseDataProvider 4 | from lib.data.providers.StaticDataProvider import StaticDataProvider 5 | from lib.data.providers.ExchangeDataProvider import ExchangeDataProvider 6 | -------------------------------------------------------------------------------- /lib/data/providers/dates/ProviderDateFormat.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class ProviderDateFormat(Enum): 5 | TIMESTAMP_UTC = 1 6 | TIMESTAMP_MS = 2 7 | DATE = 3 8 | DATETIME_HOUR_12 = 4 9 | DATETIME_HOUR_24 = 5 10 | DATETIME_MINUTE_12 = 6 11 | DATETIME_MINUTE_24 = 7 12 | CUSTOM_DATIME = 8 13 | -------------------------------------------------------------------------------- /docker/Dockerfile.cpu: -------------------------------------------------------------------------------- 1 | FROM python:3.6.8-jessie 2 | 3 | ADD ./requirements.base.txt /code/ 4 | ADD ./requirements.no-gpu.txt /code/requirements.txt 5 | 6 | WORKDIR /code 7 | 8 | RUN apt-get update \ 9 | && apt-get install -y build-essential mpich libpq-dev 10 | 11 | # should merge to top RUN to avoid extra layers - for debug only :/ 12 | RUN pip install -r requirements.txt -------------------------------------------------------------------------------- /docker/Dockerfile.tests: 
-------------------------------------------------------------------------------- 1 | FROM python:3.6.8-jessie 2 | 3 | ADD ./requirements.base.txt /code/ 4 | ADD ./requirements.no-gpu.txt /code/ 5 | ADD ./requirements.tests.txt /code/requirements.txt 6 | 7 | WORKDIR /code 8 | 9 | RUN apt-get update \ 10 | && apt-get install -y build-essential mpich libpq-dev \ 11 | && pip install --progress-bar off --requirement requirements.txt -------------------------------------------------------------------------------- /run-tests-with-docker: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | CWD="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 5 | 6 | N="trader-rl-cpu" 7 | docker build --tag $N -f docker/Dockerfile.tests "$CWD" 8 | 9 | docker run \ 10 | --user $(id -u):$(id -g) \ 11 | --interactive \ 12 | --tty \ 13 | --volume "${CWD}":/code \ 14 | "$N" \ 15 | python -m pytest -p no:warnings /code/test -------------------------------------------------------------------------------- /lib/env/reward/BaseRewardStrategy.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from abc import ABCMeta, abstractmethod 4 | from typing import List, Callable 5 | 6 | 7 | class BaseRewardStrategy(object, metaclass=ABCMeta): 8 | @abstractmethod 9 | def __init__(self): 10 | pass 11 | 12 | @abstractmethod 13 | def reset_reward(self): 14 | raise NotImplementedError() 15 | 16 | @abstractmethod 17 | def get_reward(self, 18 | current_step: int, 19 | current_price: Callable[[str], float], 20 | observations: pd.DataFrame, 21 | account_history: pd.DataFrame, 22 | net_worths: List[float]) -> float: 23 | raise NotImplementedError() 24 | -------------------------------------------------------------------------------- /lib/env/trade/BaseTradeStrategy.py: -------------------------------------------------------------------------------- 1 | from abc 
import ABCMeta, abstractmethod 2 | from typing import Tuple, Callable 3 | 4 | 5 | class BaseTradeStrategy(object, metaclass=ABCMeta): 6 | @abstractmethod 7 | def __init__(self, 8 | commissionPercent: float, 9 | maxSlippagePercent: float, 10 | base_precision: int, 11 | asset_precision: int, 12 | min_cost_limit: float, 13 | min_amount_limit: float): 14 | pass 15 | 16 | @abstractmethod 17 | def trade(self, 18 | action: int, 19 | n_discrete_actions: int, 20 | balance: float, 21 | asset_held: float, 22 | current_price: Callable[[str], float]) -> Tuple[float, float, float, float]: 23 | raise NotImplementedError() 24 | -------------------------------------------------------------------------------- /optimize.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import numpy as np 3 | 4 | from multiprocessing import Pool 5 | 6 | from lib.env.reward import WeightedUnrealizedProfit 7 | 8 | np.warnings.filterwarnings('ignore') 9 | 10 | 11 | def optimize_code(params): 12 | from lib.RLTrader import RLTrader 13 | 14 | trader = RLTrader(**params) 15 | trader.optimize() 16 | 17 | return "" 18 | 19 | 20 | if __name__ == '__main__': 21 | n_processes = multiprocessing.cpu_count() 22 | params = {'n_envs': n_processes, 'reward_strategy': WeightedUnrealizedProfit} 23 | 24 | opt_pool = Pool(processes=n_processes) 25 | results = opt_pool.imap(optimize_code, [params for _ in range(n_processes)]) 26 | 27 | print([result.get() for result in results]) 28 | 29 | from lib.RLTrader import RLTrader 30 | 31 | trader = RLTrader(**params) 32 | trader.train(test_trained_model=True, render_test_env=True, render_report=True, save_report=True) 33 | -------------------------------------------------------------------------------- /lib/env/trade/LiveTradeStrategy.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Callable 2 | from enum import Enum 3 | 4 | from lib.env.trade import 
BaseTradeStrategy 5 | 6 | 7 | class LiveTradeStrategy(BaseTradeStrategy): 8 | def __init__(self, 9 | commissionPercent: float, 10 | maxSlippagePercent: float, 11 | base_precision: int, 12 | asset_precision: int, 13 | min_cost_limit: float, 14 | min_amount_limit: float): 15 | self.commissionPercent = commissionPercent 16 | self.maxSlippagePercent = maxSlippagePercent 17 | self.base_precision = base_precision 18 | self.asset_precision = asset_precision 19 | self.min_cost_limit = min_cost_limit 20 | self.min_amount_limit = min_amount_limit 21 | 22 | def trade(self, 23 | buy_amount: float, 24 | sell_amount: float, 25 | balance: float, 26 | asset_held: float, 27 | current_price: Callable[[str], float]) -> Tuple[float, float, float, float]: 28 | raise NotImplementedError() 29 | -------------------------------------------------------------------------------- /test/data/test_providers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from lib.data.providers.dates import ProviderDateFormat 4 | from lib.data.providers import StaticDataProvider 5 | 6 | 7 | @pytest.fixture 8 | def csv_provider(): 9 | data_columns = {'Date': 'Date', 'Open': 'Open', 'High': 'High', 10 | 'Low': 'Low', 'Close': 'Close', 'Volume': 'VolumeFrom'} 11 | provider = StaticDataProvider( 12 | date_format=ProviderDateFormat.DATETIME_HOUR_24, csv_data_path="data/input/coinbase-1h-btc-usd.csv", data_columns=data_columns 13 | ) 14 | 15 | assert csv_provider is not None 16 | 17 | return provider 18 | 19 | 20 | class TestPrepareData(): 21 | def test_column_map(self, csv_provider): 22 | ohlcv = csv_provider.historical_ohlcv() 23 | 24 | expected = ['Date', 'Open', 'High', 25 | 'Low', 'Close', 'Volume'] 26 | 27 | assert (ohlcv.columns == expected).all() 28 | 29 | def test_date_sort(self, csv_provider): 30 | ohlcv = csv_provider.historical_ohlcv() 31 | 32 | timestamps = ohlcv['Date'].values 33 | sorted_timestamps = sorted(timestamps.copy()) 34 | 35 | 
assert (timestamps == sorted_timestamps).all() 36 | -------------------------------------------------------------------------------- /lib/env/reward/IncrementalProfit.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from typing import List, Callable 4 | 5 | from lib.env.reward.BaseRewardStrategy import BaseRewardStrategy 6 | 7 | 8 | class IncrementalProfit(BaseRewardStrategy): 9 | last_bought: int = 0 10 | last_sold: int = 0 11 | 12 | def __init__(self): 13 | pass 14 | 15 | def reset_reward(self): 16 | pass 17 | 18 | def get_reward(self, 19 | current_step: int, 20 | current_price: Callable[[str], float], 21 | observations: pd.DataFrame, 22 | account_history: pd.DataFrame, 23 | net_worths: List[float]) -> float: 24 | reward = 0 25 | 26 | curr_balance = account_history['balance'].values[-1] 27 | prev_balance = account_history['balance'].values[-2] if len(account_history['balance']) > 1 else curr_balance 28 | 29 | if curr_balance > prev_balance: 30 | reward = net_worths[-1] - net_worths[self.last_bought] 31 | self.last_sold = current_step 32 | elif curr_balance < prev_balance: 33 | reward = observations['Close'].values[self.last_sold] - current_price() 34 | self.last_bought = current_step 35 | 36 | return reward 37 | -------------------------------------------------------------------------------- /test/test_rl_trader.py: -------------------------------------------------------------------------------- 1 | from unittest import mock 2 | from unittest.mock import MagicMock 3 | 4 | from lib.RLTrader import RLTrader 5 | from lib.cli.RLTraderCLI import RLTraderCLI 6 | 7 | 8 | class TestRLTrader(): 9 | def setup_class(self): 10 | self.parser = RLTraderCLI().get_parser() 11 | 12 | @mock.patch.object(RLTrader, 'initialize_data', return_value=True) 13 | @mock.patch.object(RLTrader, 'optimize', return_value=True) 14 | @mock.patch.object(RLTrader, 'initialize_optuna', return_value=True) 15 | def 
test_that_args_get_injected_correctly(self, data_mock, opt_mock, init_mock): 16 | args = self.parser.parse_args(['optimize']) 17 | sut = RLTrader(**vars(args), logger=MagicMock()) 18 | sut.study_name = 'test' 19 | with mock.patch('lib.util.logger.init_logger'): 20 | assert(sut.tensorboard_path == args.tensorboard_path) 21 | assert(sut.params_db_path == args.params_db_path) 22 | assert(sut.model_verbose == args.model_verbose) 23 | assert(sut.n_minibatches == args.n_minibatches) 24 | assert(sut.train_split_percentage == args.train_split_percentage) 25 | assert(sut.input_data_path == args.input_data_path) 26 | assert(sut.model_verbose == args.model_verbose) 27 | -------------------------------------------------------------------------------- /lib/env/reward/WeightedUnrealizedProfit.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | import pandas as pd 4 | import numpy as np 5 | from typing import List, Callable 6 | 7 | from lib.env.reward.BaseRewardStrategy import BaseRewardStrategy 8 | 9 | 10 | class WeightedUnrealizedProfit(BaseRewardStrategy): 11 | def __init__(self, **kwargs): 12 | self.decay_rate = kwargs.get('decay_rate', 1e-2) 13 | self.decay_denominator = np.exp(-1 * self.decay_rate) 14 | 15 | self.reset_reward() 16 | 17 | def reset_reward(self): 18 | self.rewards = deque(np.zeros(1, dtype=float)) 19 | self.sum = 0.0 20 | 21 | def calc_reward(self, reward): 22 | self.sum = self.sum - self.decay_denominator * self.rewards.popleft() 23 | self.sum = self.sum * self.decay_denominator 24 | self.sum = self.sum + reward 25 | 26 | self.rewards.append(reward) 27 | 28 | return self.sum / self.decay_denominator 29 | 30 | def get_reward(self, 31 | current_step: int, 32 | current_price: Callable[[str], float], 33 | observations: pd.DataFrame, 34 | account_history: pd.DataFrame, 35 | net_worths: List[float]) -> float: 36 | if account_history['asset_sold'].values[-1] > 0: 37 | reward = 
self.calc_reward(account_history['sale_revenue'].values[-1]) 38 | else: 39 | reward = self.calc_reward(account_history['asset_held'].values[-1] * current_price()) 40 | 41 | return reward 42 | -------------------------------------------------------------------------------- /lib/cli/functions/update_data.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import ssl 3 | import pandas as pd 4 | import os 5 | 6 | final_date_format = '%Y-%m-%d %H:%M' 7 | ssl._create_default_https_context = ssl._create_unverified_context 8 | 9 | hourly_url = "https://www.cryptodatadownload.com/cdd/Coinbase_BTCUSD_1h.csv" 10 | daily_url = "https://www.cryptodatadownload.com/cdd/Coinbase_BTCUSD_d.csv" 11 | 12 | 13 | async def save_url_to_csv(url: str, date_format: str, file_name: str): 14 | csv = pd.read_csv(url, header=1) 15 | csv = csv.dropna(thresh=2) 16 | csv.columns = ['Date', 'Symbol', 'Open', 'High', 'Low', 'Close', 'VolumeFrom', 'VolumeTo'] 17 | csv['Date'] = pd.to_datetime(csv['Date'], format=date_format) 18 | csv['Date'] = csv['Date'].dt.strftime(final_date_format) 19 | 20 | final_path = os.path.join('data', 'input', file_name) 21 | csv.to_csv(final_path, index=False) 22 | 23 | return csv 24 | 25 | 26 | async def save_as_csv(hourly_url: str, daily_url: str): 27 | tasks = [save_url_to_csv(hourly_url, '%Y-%m-%d %I-%p', 'coinbase-1h-btc-usd.csv'), 28 | save_url_to_csv(daily_url, '%Y-%m-%d', 'coinbase-1d-btc-usd.csv')] 29 | # also FIRST_EXCEPTION and ALL_COMPLETED (default) 30 | done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) 31 | print('>> done: ', done) 32 | print('>> pending: ', pending) # will be empty if using default return_when setting 33 | 34 | 35 | def download_data_async(): 36 | loop = asyncio.get_event_loop() 37 | loop.run_until_complete(save_as_csv(hourly_url, daily_url)) 38 | loop.close() 39 | 40 | 41 | if __name__ == '__main__': 42 | download_data_async() 43 | 
-------------------------------------------------------------------------------- /run-with-docker: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | CWD="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 6 | 7 | if [[ -z $1 ]]; then 8 | echo "Should have 1 argument: cpu or gpu" 9 | exit 10 | fi 11 | 12 | TYPE=$1 13 | shift; 14 | 15 | if [[ -n $2 ]]; then 16 | docker build \ 17 | --tag 'trader-rl-postgres' \ 18 | --build-arg ID=$(id -u) \ 19 | --build-arg GI=$(id -g) \ 20 | -f "$CWD/docker/Dockerfile.backend" "$CWD" 21 | 22 | mkdir -p "$CWD/data/postgres" 23 | docker run \ 24 | --detach \ 25 | --publish 5432:5432 \ 26 | --tty \ 27 | --user "$(id -u):$(id -g)" \ 28 | --volume "$CWD/data/postgres":"/var/lib/postgresql/data/trader-data" \ 29 | trader-rl-postgres 30 | shift 31 | fi 32 | 33 | if [[ $TYPE == 'gpu' ]]; then 34 | GPU=1 35 | else 36 | GPU=0 37 | fi 38 | 39 | MEM=$(cat /proc/meminfo | grep 'MemTotal:' | awk '{ print $2 }') 40 | CPUS=$(cat /proc/cpuinfo | grep -P 'processor.+[0-7]+' | wc -l) 41 | 42 | MEM_LIMIT=$((MEM/4*3)) 43 | CPU_LIMIT=$((CPUS/4*3)) 44 | 45 | if [ $CPU_LIMIT == 0 ];then 46 | CPU_LIMIT=1 47 | fi 48 | 49 | if [ $GPU == 0 ]; then 50 | N="trader-rl-cpu" 51 | docker build --tag $N -f "$CWD/docker/Dockerfile.cpu" "$CWD" 52 | else 53 | N="trader-rl-gpu" 54 | docker build --tag $N -f "$CWD/docker/Dockerfile.gpu" "$CWD" 55 | fi 56 | 57 | echo "CWD: $CWD - Procs: $CPU_LIMIT Memory: ${MEM_LIMIT}bytes" 58 | docker run \ 59 | --user $(id -u):$(id -g) \ 60 | --interactive \ 61 | --memory "${MEM_LIMIT}b" \ 62 | --cpus "${CPU_LIMIT}" \ 63 | --tty \ 64 | --volume "${CWD}":/code \ 65 | "$N" \ 66 | python /code/cli.py $@ -------------------------------------------------------------------------------- /dev-with-docker: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | SCRIPT_DIR=$(dirname 
"${BASH_SOURCE[0]}") 6 | CWD=$(realpath "${SCRIPT_DIR}") 7 | 8 | if [[ -z $1 ]]; then 9 | echo "Should have 1 argument: cpu or gpu" 10 | exit 11 | fi 12 | 13 | TYPE=$1 14 | shift; 15 | 16 | if [[ -n $2 ]]; then 17 | docker build \ 18 | --tag 'trader-rl-postgres' \ 19 | --build-arg ID=$(id -u) \ 20 | --build-arg GI=$(id -g) \ 21 | -f "$CWD/docker/Dockerfile.backend" "$CWD" 22 | 23 | mkdir -p "$CWD/data/postgres" 24 | docker run \ 25 | --detach \ 26 | --publish 5432:5432 \ 27 | --tty \ 28 | --user "$(id -u):$(id -g)" \ 29 | --volume "$CWD/data/postgres":"/var/lib/postgresql/data/trader-data" \ 30 | trader-rl-postgres-dev 31 | shift 32 | fi 33 | 34 | if [[ $TYPE == 'gpu' ]]; then 35 | GPU=1 36 | else 37 | GPU=0 38 | fi 39 | 40 | MEM=$(cat /proc/meminfo | grep 'MemTotal:' | awk '{ print $2 }') 41 | CPUS=$(cat /proc/cpuinfo | grep -P 'processor.+[0-7]+' | wc -l) 42 | 43 | MEM_LIMIT=$((MEM/4*3)) 44 | CPU_LIMIT=$((CPUS/4*3)) 45 | 46 | if [ $CPU_LIMIT == 0 ];then 47 | CPU_LIMIT=1 48 | fi 49 | 50 | if [ $GPU == 0 ]; then 51 | N="trader-rl-cpu-dev" 52 | docker build --tag $N -f "$CWD/docker/Dockerfile.cpu" "$CWD" 53 | else 54 | N="trader-rl-gpu-dev" 55 | docker build --tag $N -f "$CWD/docker/Dockerfile.gpu" "$CWD" 56 | fi 57 | 58 | docker rm -fv rl_trader_dev || true 59 | docker run \ 60 | --name 'rl_trader_dev' \ 61 | --user $(id -u):$(id -g) \ 62 | --entrypoint 'bash' \ 63 | --interactive \ 64 | --memory "${MEM_LIMIT}b" \ 65 | --cpus "${CPU_LIMIT}" \ 66 | --tty \ 67 | --volume "${CWD}":/code \ 68 | "$N" -------------------------------------------------------------------------------- /lib/util/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import colorlog 4 | 5 | def init_logger(dunder_name, show_debug=False) -> logging.Logger: 6 | log_format = ( 7 | '%(asctime)s - ' 8 | '%(name)s - ' 9 | '%(funcName)s - ' 10 | '%(levelname)s - ' 11 | '%(message)s' 12 | ) 13 | bold_seq = '\033[1m' 14 | 
colorlog_format = ( 15 | f'{bold_seq} ' 16 | '%(log_color)s ' 17 | f'{log_format}' 18 | ) 19 | colorlog.basicConfig(format=colorlog_format) 20 | logging.getLogger('tensorflow').disabled = True 21 | logger = logging.getLogger(dunder_name) 22 | 23 | if show_debug: 24 | logger.setLevel(logging.DEBUG) 25 | else: 26 | logger.setLevel(logging.INFO) 27 | 28 | # Note: these file outputs are left in place as examples 29 | # Feel free to uncomment and use the outputs as you like 30 | 31 | # Output full log 32 | # fh = logging.FileHandler(os.path.join('data', 'log', 'trading.log')) 33 | # fh.setLevel(logging.DEBUG) 34 | # formatter = logging.Formatter(log_format) 35 | # fh.setFormatter(formatter) 36 | # logger.addHandler(fh) 37 | 38 | # # Output warning log 39 | # fh = logging.FileHandler(os.path.join('data', 'log', 'trading.warning.log')) 40 | # fh.setLevel(logging.WARNING) 41 | # formatter = logging.Formatter(log_format) 42 | # fh.setFormatter(formatter) 43 | # logger.addHandler(fh) 44 | 45 | # # Output error log 46 | # fh = logging.FileHandler(os.path.join('data', 'log', 'trading.error.log')) 47 | # fh.setLevel(logging.ERROR) 48 | # formatter = logging.Formatter(log_format) 49 | # fh.setFormatter(formatter) 50 | # logger.addHandler(fh) 51 | 52 | return logger 53 | -------------------------------------------------------------------------------- /lib/env/trade/SimulatedTradeStrategy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from typing import Tuple, Callable 4 | 5 | from lib.env.trade import BaseTradeStrategy 6 | 7 | 8 | class SimulatedTradeStrategy(BaseTradeStrategy): 9 | def __init__(self, 10 | commissionPercent: float, 11 | maxSlippagePercent: float, 12 | base_precision: int, 13 | asset_precision: int, 14 | min_cost_limit: float, 15 | min_amount_limit: float): 16 | self.commissionPercent = commissionPercent 17 | self.maxSlippagePercent = maxSlippagePercent 18 | self.base_precision = base_precision 19 | 
self.asset_precision = asset_precision 20 | self.min_cost_limit = min_cost_limit 21 | self.min_amount_limit = min_amount_limit 22 | 23 | def trade(self, 24 | buy_amount: float, 25 | sell_amount: float, 26 | balance: float, 27 | asset_held: float, 28 | current_price: Callable[[str], float]) -> Tuple[float, float, float, float]: 29 | current_price = current_price('Close') 30 | commission = self.commissionPercent / 100 31 | slippage = np.random.uniform(0, self.maxSlippagePercent) / 100 32 | 33 | asset_bought, asset_sold, purchase_cost, sale_revenue = buy_amount, sell_amount, 0, 0 34 | 35 | if buy_amount > 0 and balance >= self.min_cost_limit: 36 | price_adjustment = (1 + commission) * (1 + slippage) 37 | buy_price = round(current_price * price_adjustment, self.base_precision) 38 | purchase_cost = round(buy_price * buy_amount, self.base_precision) 39 | elif sell_amount > 0 and asset_held >= self.min_amount_limit: 40 | price_adjustment = (1 - commission) * (1 - slippage) 41 | sell_price = round(current_price * price_adjustment, self.base_precision) 42 | sale_revenue = round(sell_amount * sell_price, self.base_precision) 43 | 44 | return asset_bought, asset_sold, purchase_cost, sale_revenue 45 | -------------------------------------------------------------------------------- /cli.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from multiprocessing import Process 4 | 5 | from lib.cli.RLTraderCLI import RLTraderCLI 6 | from lib.util.logger import init_logger 7 | from lib.cli.functions import download_data_async 8 | from lib.env.reward import BaseRewardStrategy, IncrementalProfit, WeightedUnrealizedProfit 9 | 10 | np.warnings.filterwarnings('ignore') 11 | 12 | trader_cli = RLTraderCLI() 13 | args = trader_cli.get_args() 14 | 15 | rewards = {"incremental-profit": IncrementalProfit, "weighted-unrealized-profit": WeightedUnrealizedProfit} 16 | reward_strategy = rewards[args.reward_strat] 17 | 18 | 19 | def 
run_optimize(args, logger): 20 | from lib.RLTrader import RLTrader 21 | 22 | trader = RLTrader(**vars(args), logger=logger, reward_strategy=reward_strategy) 23 | trader.optimize(n_trials=args.trials) 24 | 25 | 26 | if __name__ == '__main__': 27 | logger = init_logger(__name__, show_debug=args.debug) 28 | 29 | if args.command == 'optimize': 30 | n_processes = args.parallel_jobs 31 | 32 | processes = [] 33 | for _ in range(n_processes): 34 | processes.append(Process(target=run_optimize, args=(args, logger))) 35 | 36 | for proc in processes: 37 | proc.start() 38 | 39 | for proc in processes: 40 | proc.join() 41 | 42 | from lib.RLTrader import RLTrader 43 | 44 | trader = RLTrader(**vars(args), logger=logger, reward_strategy=reward_strategy) 45 | 46 | if args.command == 'train': 47 | trader.train(n_epochs=args.epochs, 48 | save_every=args.save_every, 49 | test_trained_model=args.test_trained, 50 | render_test_env=args.render_test, 51 | render_report=args.render_report, 52 | save_report=args.save_report) 53 | elif args.command == 'test': 54 | trader.test(model_epoch=args.model_epoch, 55 | render_env=args.render_env, 56 | render_report=args.render_report, 57 | save_report=args.save_report) 58 | elif args.command == 'update-static-data': 59 | download_data_async() 60 | -------------------------------------------------------------------------------- /lib/util/benchmarks.py: -------------------------------------------------------------------------------- 1 | import ta 2 | from enum import Enum 3 | 4 | 5 | class SIGNALS(Enum): 6 | HOLD = 0 7 | BUY = 1 8 | SELL = 2 9 | 10 | 11 | def trade_strategy(prices, initial_balance, commission, signal_fn): 12 | net_worths = [initial_balance] 13 | balance = initial_balance 14 | amount_held = 0 15 | 16 | for i in range(1, len(prices)): 17 | if amount_held > 0: 18 | net_worths.append(balance + amount_held * prices[i]) 19 | else: 20 | net_worths.append(balance) 21 | 22 | signal = signal_fn(i) 23 | 24 | if signal == SIGNALS.SELL and 
def transform(iterable: Iterable, inplace: bool = True, columns: List[str] = None,
              transform_fn: Callable[[Iterable], Iterable] = None):
    """Apply ``transform_fn`` column by column and sanitise the result.

    Missing values are replaced by 0 before the transformation. Afterwards
    remaining gaps are back-filled and every value that is still not finite
    (NaN or +/-inf) is set to 0.

    Args:
        iterable: A ``pandas.DataFrame`` or anything ``pandas.DataFrame`` can
            be built from (e.g. a list of rows).
        inplace: When true, a DataFrame input is modified and returned
            itself; otherwise ``iterable.copy()`` is transformed instead.
        columns: Columns to transform; defaults to all columns. For
            non-DataFrame input it is also used as the column labels.
        transform_fn: Callable mapping one column to its transformed column.

    Returns:
        The transformed DataFrame, or its ``.values`` array when the input
        was not a DataFrame.

    Raises:
        NotImplementedError: If ``transform_fn`` is not given. This is checked
            before any data is touched, so the input is left unmodified.
    """
    if transform_fn is None:
        raise NotImplementedError()

    frame = iterable if inplace is True else iterable.copy()

    was_frame = isinstance(frame, pd.DataFrame)
    if not was_frame:
        frame = pd.DataFrame(frame, columns=columns)

    frame.fillna(0, inplace=True)

    if columns is None:
        columns = frame.columns

    for column in columns:
        frame[column] = transform_fn(frame[column])

    # DataFrame.bfill replaces the deprecated fillna(method="bfill").
    frame.bfill(inplace=True)
    # Whatever is still NaN or infinite (e.g. 0/0 on a constant column, or a
    # trailing gap with nothing to back-fill from) becomes 0.
    frame[~np.isfinite(frame)] = 0

    return frame if was_frame else frame.values


def max_min_normalize(iterable: Iterable, inplace: bool = True, columns: List[str] = None):
    """Rescale each column to ``(x - min) / (max - min)``."""
    def _scale(column):
        low = column.min()
        return (column - low) / (column.max() - low)

    return transform(iterable, inplace, columns, _scale)


def mean_normalize(iterable: Iterable, inplace: bool = True, columns: List[str] = None):
    """Standardise each column to ``(x - mean) / std``."""
    def _standardise(column):
        return (column - column.mean()) / column.std()

    return transform(iterable, inplace, columns, _standardise)


def difference(iterable: Iterable, inplace: bool = True, columns: List[str] = None):
    """Replace each column by its first difference ``x[t] - x[t-1]``."""
    def _diff(column):
        return column - column.shift(1)

    return transform(iterable, inplace, columns, _diff)


def log_and_difference(iterable: Iterable, inplace: bool = True, columns: List[str] = None):
    """Replace each column by the first difference of its natural logarithm."""
    def _log_diff(column):
        logged = np.log(column)  # computed once, reused for the shifted term
        return logged - logged.shift(1)

    return transform(iterable, inplace, columns, _log_diff)
def add_indicators(df) -> pd.DataFrame:
    """Append one column per entry of the module-level ``indicators`` table.

    Each entry is ``(name, func, arg_names)``: ``func`` is called with the
    columns of ``df`` named in ``arg_names`` and its result is stored in
    ``df[name]``. Entries are processed in table order, so an entry may use
    columns created by earlier entries. After all columns are added, gaps are
    back-filled.

    Args:
        df: DataFrame holding the columns referenced by the table; it is
            modified in place.

    Returns:
        The same ``df`` object with the indicator columns added.
    """
    for name, func, arg_names in indicators:
        df[name] = func(*[df[arg_name] for arg_name in arg_names])

    # DataFrame.bfill replaces the deprecated fillna(method='bfill').
    df.bfill(inplace=True)
    return df
22 | end 23 | end 24 | 25 | # Configure hosts 26 | if Vagrant.has_plugin?('vagrant-hostmanager') 27 | config.hostmanager.enabled = true 28 | config.hostmanager.manage_host = true 29 | config.hostmanager.manage_guest = true 30 | config.hostmanager.ignore_private_ip = false 31 | config.hostmanager.aliases = ['trader-rl.local'] 32 | end 33 | 34 | # Set auto_update to false, if you do NOT want to check the correct virtual-box-guest-additions version when booting VM 35 | if Vagrant.has_plugin?('vagrant-vbguest') 36 | config.vbguest.auto_update = false 37 | end 38 | 39 | config.vm.define 'trader-rl-vagrant', primary: true do |vm_config| 40 | vm_config.vm.box = 'ubuntu/bionic64' 41 | vm_config.vm.box_check_update = true 42 | vm_config.vm.network 'private_network', ip: machine_ip_address 43 | vm_config.vm.provider 'virtualbox' do |vb| 44 | vb.name = 'trader-rl' 45 | vb.cpus = 8 46 | vb.memory = 20480 47 | end 48 | 49 | vm_config.vm.hostname = 'trader-rl' 50 | vm_config.ssh.insert_key = false 51 | 52 | vm_config.vm.synced_folder '.', '/vagrant', disabled: false 53 | vm_config.vm.provision "default setup", type: "shell", inline: <