├── .env.example ├── .gitignore ├── Makefile ├── README.md ├── images ├── diagram.jpg └── logo_realworldml.png ├── poetry.lock ├── pyproject.toml ├── src ├── __init__.py ├── backfill_technical_indicators.py ├── config.py ├── date_utils.py ├── feature_store_api.py ├── flow_steps.py ├── logger.py ├── old │ ├── 01_basic_llm_chain.py │ └── 02_trading_bot_fake_context.py └── technical_indicators_pipeline.py └── tests └── __init__.py /.env.example: -------------------------------------------------------------------------------- 1 | OPENAI_API_KEY=YOUR_OPENAI_API_KEY_GOES_HERE 2 | 3 | COMET_API_KEY=YOUR_COMET_API_KEY_GOES_HERE 4 | COMET_PROJECT_NAME=YOUR_COMET_PROJECT_NAME_GOES_HERE 5 | 6 | HOPSWORKS_PROJECT_NAME=YOUR_HOPSWORKS_PROJECT_NAME_GOES_HERE 7 | HOPSWORKS_API_KEY=YOUR_HOPSWORKS_API_KEY_GOES_HERE 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # prefect artifacts 2 | .prefectignore 3 | 4 | # python artifacts 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | *.egg-info/ 9 | *.egg 10 | 11 | # Type checking artifacts 12 | .mypy_cache/ 13 | .dmypy.json 14 | dmypy.json 15 | .pyre/ 16 | 17 | # IPython 18 | profile_default/ 19 | ipython_config.py 20 | *.ipynb_checkpoints/* 21 | 22 | # Environments 23 | .python-version 24 | .env 25 | .venv 26 | env/ 27 | venv/ 28 | 29 | # MacOS 30 | .DS_Store 31 | 32 | # Dask 33 | dask-worker-space/ 34 | 35 | # Editors 36 | .idea/ 37 | .vscode/ 38 | 39 | # VCS 40 | .git/ 41 | .hg/ 42 | 43 | .env 44 | .venv/ 45 | *.tar 46 | __pycache__/ 47 | .DS_Store 48 | set_env_variables.sh -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: init debug_technical_indicators run_technical_indicators backfill_technical_indicators 2 | 3 | ### Install ### 4 | 5 | # install Poetry and Python dependencies 6 | init: 7 | curl -sSL https://install.python-poetry.org | python3 - 8 | poetry install 9 | 10 | 11 | ### 
Technical Indicators Pipeline ### 12 | 13 | # run the feature-pipeline locally and print out the results on the console 14 | debug_technical_indicators: 15 | poetry run python -m bytewax.run "src.technical_indicators_pipeline:get_dataflow(execution_mode='DEBUG')" 16 | 17 | # run the feature-pipeline and send the features to the online feature store 18 | run_technical_indicators: 19 | poetry run python -m bytewax.run "src.technical_indicators_pipeline:get_dataflow()" 20 | 21 | # backfills the feature group using historical data 22 | backfill_technical_indicators: 23 | poetry run python src/backfill_technical_indicators.py --from_day $(from_day) --product_id XBT/USD 24 | 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |

Build a trading bot with OpenAI GPT-3.5, real-time data, and prompt experimentation

4 | 5 |
6 | 7 |
8 | 9 | ## TODOs 10 | 11 | - [ ] Build and run dataflow for technical indicators 12 | 13 | 14 | - [ ] Build news summarization pipeline. 15 | - [ ] Create online Feature Store 16 | - [ ] `.env` file with credentials 17 | - [ ] Copy working code. 18 | - [ ] 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /images/diagram.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Paulescu/trading-bot-gpt/2bf95306d6b0488de7bfdff5da784113eacf2b68/images/diagram.jpg -------------------------------------------------------------------------------- /images/logo_realworldml.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Paulescu/trading-bot-gpt/2bf95306d6b0488de7bfdff5da784113eacf2b68/images/logo_realworldml.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "src" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Pau "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = ">=3.10,<3.11" 10 | langchain = "^0.0.286" 11 | python-dotenv = "^1.0.0" 12 | openai = "^0.28.0" 13 | comet-llm = "^1.3.0" 14 | bytewax = "==0.16.*" 15 | pandas = "2.0.0" 16 | hsfs = "^3.4.2" 17 | hopsworks = "^3.4.2" 18 | prefect = "^2.13.7" 19 | 20 | 21 | [build-system] 22 | requires = ["poetry-core"] 23 | build-backend = "poetry.core.masonry.api" 24 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Paulescu/trading-bot-gpt/2bf95306d6b0488de7bfdff5da784113eacf2b68/src/__init__.py -------------------------------------------------------------------------------- 
# --------------------------------------------------------------------------------
# /src/backfill_technical_indicators.py
# --------------------------------------------------------------------------------
"""Backfill the OHLC feature group with historical trades from Kraken.

For every requested day this script downloads the raw trades from Kraken's
public REST `Trades` endpoint, turns them into OHLC candles by replaying them
through the bytewax dataflow in BACKFILL mode, and saves the candles to the
offline feature group.

Usage:
    python src/backfill_technical_indicators.py --day 2023-10-01 --product_id XBT/USD
    python src/backfill_technical_indicators.py --from_day 2023-10-01 [--to_day 2023-10-07]
"""
import os
import tempfile
from argparse import ArgumentParser
from datetime import datetime, timedelta, date
from typing import List

import pytz
import requests
import pandas as pd
from prefect import task, flow
from prefect.task_runners import SequentialTaskRunner

from src.feature_store_api import save_data_to_offline_feature_group
from src import config
from src.logger import get_console_logger
from src.technical_indicators_pipeline import run_dataflow_in_backfill_mode

logger = get_console_logger()

PAIR = 'xbtusd'
OHLC_FREQUENCY_IN_MINUTES = 1

KRAKEN_TRADES_URL = 'https://api.kraken.com/0/public/Trades'

# Seconds to wait for Kraken before the request fails (the task then retries).
REQUEST_TIMEOUT_SECONDS = 30

# Product id used in this project -> pair name expected by the `Trades` endpoint.
MAP_PRODUCT_ID_TO_HISTORICAL_API_PAIR_NAME = {
    'XBT/USD': 'xbtusd',
}

# Product id -> key under which the `Trades` endpoint nests the trades in `result`.
MAP_PRODUCT_ID_TO_HISTORICAL_API_RESPONSE_KEY = {
    'XBT/USD': 'XXBTZUSD',
}

# Columns (and their dtypes) of the trades DataFrame built by
# fetch_historical_data_one_day.
TRADE_DTYPES = {
    'price': float,
    'volume': float,
    'timestamp': int,
    'product_id': str,
}


@task(retries=3, retry_delay_seconds=60)
def fetch_data_from_kraken_api(
    product_id: str,
    since_nano_seconds: int
) -> List[List]:
    """Fetch one page of raw trades for `product_id` from Kraken's REST API.

    Args:
        product_id (str): product id, e.g. 'XBT/USD'. Must be a key of
            MAP_PRODUCT_ID_TO_HISTORICAL_API_PAIR_NAME.
        since_nano_seconds (int): Unix timestamp in nanoseconds used as the
            `since` cursor of the request (a float is accepted as well).

    Returns:
        List[List]: one list per trade, as returned by Kraken. This module
            reads price (index 0), volume (index 1) and time in Unix seconds
            (index 2).

    Raises:
        requests.HTTPError: on a non-2xx HTTP status.
        RuntimeError: when the response body reports API errors.
    """
    pair = MAP_PRODUCT_ID_TO_HISTORICAL_API_PAIR_NAME[product_id]

    response = requests.get(
        KRAKEN_TRADES_URL,
        params={'pair': pair, 'since': f'{since_nano_seconds:.0f}'},
        # without a timeout a stalled connection would hang the flow forever
        timeout=REQUEST_TIMEOUT_SECONDS,
    )
    response.raise_for_status()
    payload = response.json()

    # Kraken lists API-level failures (e.g. rate limiting) in the `error`
    # field; surface them instead of failing later with a KeyError on `result`.
    if payload.get('error'):
        raise RuntimeError(f'Kraken API error for {product_id}: {payload["error"]}')

    response_key = MAP_PRODUCT_ID_TO_HISTORICAL_API_RESPONSE_KEY[product_id]
    return payload['result'][response_key]


@flow
def fetch_historical_data_one_day(day: datetime, product_id: str) -> pd.DataFrame:
    """Download every trade of `product_id` in the 24h window starting at `day`.

    Pages through Kraken's `Trades` endpoint, using the timestamp of the last
    trade of each page as the `since` cursor of the next request.

    Args:
        day (datetime): start of the window. The CLI passes midnight UTC;
            `.timestamp()` is used, so a naive datetime would be read as
            local time.
        product_id (str): e.g. 'XBT/USD'

    Returns:
        pd.DataFrame: one row per trade with `day <= timestamp < day + 1 day`
            and the columns/dtypes of TRADE_DTYPES (timestamp truncated to
            whole seconds). Empty, but with those columns, when there are no
            trades.
    """
    # half-open time range [from_ts, to_ts) we want to fetch data for
    from_ts = day.timestamp()
    to_ts = (day + timedelta(days=1)).timestamp()

    trades = []
    last_ts = from_ts
    while last_ts < to_ts:

        # fetch one page of trades from Kraken API
        trade_params = fetch_data_from_kraken_api(
            product_id,
            since_nano_seconds=last_ts * 1_000_000_000,
        )

        trades_in_batch = [
            {
                'price': params[0],
                'volume': params[1],
                'timestamp': params[2],
                'product_id': product_id,
            }
            for params in trade_params
        ]

        if not trades_in_batch:
            logger.info(f'No more data for {product_id} from {datetime.utcfromtimestamp(last_ts)}')
            break

        from_date = datetime.utcfromtimestamp(trades_in_batch[0]['timestamp'])
        to_date = datetime.utcfromtimestamp(trades_in_batch[-1]['timestamp'])
        logger.info(f'trades_in_batch from {from_date} to {to_date}')

        trades.extend(trades_in_batch)

        # The cursor must move forward; otherwise (e.g. the API keeps
        # returning the latest trade again once we reach "now") the loop
        # would request the same page forever.
        next_ts = trades_in_batch[-1]['timestamp']
        if next_ts <= last_ts:
            logger.info(f'Cursor did not advance past {datetime.utcfromtimestamp(last_ts)}, stopping')
            break
        last_ts = next_ts

    # Keep only trades inside [from_ts, to_ts). The upper bound is exclusive so
    # a trade exactly at midnight belongs to the next day only, instead of
    # being counted by two consecutive daily backfills.
    trades = [t for t in trades if from_ts <= t['timestamp'] < to_ts]
    logger.info(f'{len(trades)=}')

    # passing `columns` keeps the schema even when `trades` is empty
    trades = pd.DataFrame(trades, columns=list(TRADE_DTYPES))
    return trades.astype(TRADE_DTYPES)


def save_data_to_csv_file(data: pd.DataFrame) -> str:
    """Write `data` (without its index) to a new temporary CSV file.

    A unique file is created on every call, so concurrent or quick successive
    calls never overwrite each other. The caller must delete the file.

    Returns:
        str: path to the CSV file
    """
    fd, tmp_file_path = tempfile.mkstemp(suffix='.csv')
    os.close(fd)
    data.to_csv(tmp_file_path, index=False)
    return tmp_file_path


@task
def transform_trades_to_ohlc(trades: pd.DataFrame) -> pd.DataFrame:
    """Aggregate raw trades into OHLC candles using the bytewax dataflow.

    The trades are dumped to a temporary CSV file, which the dataflow reads
    in BACKFILL mode. The file is removed even if the dataflow fails.
    """
    file_path = save_data_to_csv_file(trades)
    logger.info(f'Saved trades to temporary file {file_path}')
    try:
        logger.info('Creating dataflow for backfilling and running it')
        ohlc = run_dataflow_in_backfill_mode(input_file=file_path)
    finally:
        logger.info(f'Removing temporary file {file_path}')
        os.remove(file_path)

    return ohlc


@task(retries=3, retry_delay_seconds=60)
def save_ohlc_to_feature_store(ohlc: pd.DataFrame) -> None:
    """Save OHLC candles to the offline feature group from config.FEATURE_GROUP_METADATA."""
    logger.info('Saving OHLC data to offline feature store')
    save_data_to_offline_feature_group(
        data=ohlc,
        feature_group_config=config.FEATURE_GROUP_METADATA
    )


@flow(task_runner=SequentialTaskRunner())
def backfill_one_day(day: datetime, product_id: str):
    """Backfill OHLC data in the feature store for the given `day`.

    Days without trades (or without resulting candles) are skipped instead of
    pushing an empty DataFrame to the feature store.
    """
    # fetch trade data from external API
    trades: pd.DataFrame = fetch_historical_data_one_day(day, product_id)
    if trades.empty:
        logger.info(f'No trades found, nothing to backfill, {day=} {product_id=}')
        return

    # transform trade data to OHLC data
    ohlc: pd.DataFrame = transform_trades_to_ohlc(trades)
    if ohlc.empty:
        logger.info(f'Dataflow produced no OHLC rows, nothing to backfill, {day=} {product_id=}')
        return

    # push OHLC data to the feature store
    logger.info(f'Pushing OHLC data to offline feature group, {day=} {product_id=}')
    save_ohlc_to_feature_store(ohlc)


@flow
def backfill_range_dates(from_day: datetime, to_day: datetime, product_id: str):
    """Backfill OHLC data for every day from `from_day` to `to_day` (both included)."""
    days = pd.date_range(from_day, to_day, freq='D')
    for day in days:
        backfill_one_day(day, product_id=product_id)


if __name__ == "__main__":

    def parse_utc_day(s: str) -> datetime:
        """Parse a 'YYYY-MM-DD' string into a timezone-aware midnight-UTC datetime."""
        return datetime.strptime(s, '%Y-%m-%d').replace(tzinfo=pytz.UTC)

    parser = ArgumentParser()
    parser.add_argument('--from_day', type=parse_utc_day)
    parser.add_argument(
        '--to_day',
        type=parse_utc_day,
        default=datetime.utcnow().replace(tzinfo=pytz.UTC)
    )
    parser.add_argument('--day', type=parse_utc_day)
    parser.add_argument(
        '--product_id', type=str,
        # the only product with a known Kraken REST mapping
        default='XBT/USD',
        help="For example: XBT/USD",
    )

    args = parser.parse_args()

    if args.day:
        logger.info(f'Starting backfilling process for {args.day}')
        backfill_one_day(day=args.day, product_id=args.product_id)

    else:
        # parser.error (unlike assert) is not stripped under `python -O`
        if args.from_day is None:
            parser.error('either --day or --from_day is required')
        if args.from_day > args.to_day:
            parser.error('from_day has to be smaller or equal than to_day')
        logger.info(f'Starting backfilling process from {args.from_day} to {args.to_day}')
        backfill_range_dates(from_day=args.from_day, to_day=args.to_day, product_id=args.product_id)


# --------------------------------------------------------------------------------
# /src/config.py
# --------------------------------------------------------------------------------
"""Static configuration shared by the live pipeline and the backfill script."""
from src.feature_store_api import FeatureGroupConfig, FeatureViewConfig

# length of the tumbling window (and therefore of each OHLC candle), in seconds
WINDOW_SECONDS = 60

# products the live websocket pipeline subscribes to
# NOTE(review): the backfill script writes product_id 'XBT/USD'; confirm the
# live feed writes the same product_id value, since it is part of the primary key.
PRODUCT_IDS = [
    "BTC/USD",
    # "ETH/USD",
]

FEATURE_GROUP_METADATA = FeatureGroupConfig(
    name=f'ohlc_{WINDOW_SECONDS}_sec',
    version=1,
    description=f"OHLC data with technical indicators every {WINDOW_SECONDS} seconds",
    primary_key=['product_id', 'time'],
    event_time='time',
    online_enabled=True,
)

FEATURE_VIEW_CONFIG = FeatureViewConfig(
    name=f'ohlc_{WINDOW_SECONDS}_sec_view',
    version=1,
    feature_group=FEATURE_GROUP_METADATA,
)


# --------------------------------------------------------------------------------
# /src/date_utils.py
# --------------------------------------------------------------------------------
"""Small helpers to convert between strings, Unix timestamps and UTC datetimes."""
from datetime import datetime, timezone


def str2epoch(x: str) -> float:
    """Convert a 'YYYY-MM-DDTHH:MM:SS.ffffffZ' UTC string into Unix seconds."""
    return str2datetime(x).timestamp()


def str2datetime(s: str) -> datetime:
    """Parse a 'YYYY-MM-DDTHH:MM:SS.ffffffZ' string into an aware UTC datetime."""
    return datetime.strptime(s, "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc)


def epoch2datetime(epoch: float) -> datetime:
    """Convert Unix seconds (int or float) into an aware UTC datetime."""
    return datetime.fromtimestamp(epoch, tz=timezone.utc)


# --------------------------------------------------------------------------------
# /src/feature_store_api.py
# --------------------------------------------------------------------------------
"""Thin wrapper around the Hopsworks feature store used by the pipelines."""
import os
from typing import List, Optional, Tuple
from dataclasses import dataclass

import hsfs
import hopsworks
import pandas as pd

from src.logger import get_console_logger

logger = get_console_logger()

# Feature Store credentials, read from environment variables. They are read
# with `.get` so that merely importing this module (src.config does, so every
# pipeline mode does) does not crash when they are unset; they are validated
# when a connection is actually opened, in get_feature_store().
HOPSWORKS_PROJECT_NAME: Optional[str] = os.environ.get('HOPSWORKS_PROJECT_NAME')
HOPSWORKS_API_KEY: Optional[str] = os.environ.get('HOPSWORKS_API_KEY')


@dataclass
class FeatureGroupConfig:
    name: str
    version: int
    description: str
    primary_key: List[str]
    event_time: str
    online_enabled: bool


@dataclass
class FeatureViewConfig:
    name: str
    version: int
    feature_group: FeatureGroupConfig


def _get_hopsworks_credentials() -> Tuple[str, str]:
    """Return `(project_name, api_key)`, raising a clear error if either is missing."""
    # fall back to the environment at call time, in case it was set after import
    project_name = HOPSWORKS_PROJECT_NAME or os.environ.get('HOPSWORKS_PROJECT_NAME')
    api_key = HOPSWORKS_API_KEY or os.environ.get('HOPSWORKS_API_KEY')

    missing = [
        var_name
        for var_name, value in (
            ('HOPSWORKS_PROJECT_NAME', project_name),
            ('HOPSWORKS_API_KEY', api_key),
        )
        if not value
    ]
    if missing:
        raise RuntimeError(
            f"Missing environment variable(s) needed to connect to Hopsworks: {', '.join(missing)}"
        )
    return project_name, api_key


def get_feature_store() -> hsfs.feature_store.FeatureStore:
    """Connect to Hopsworks and return a pointer to the project's feature store.

    Returns:
        hsfs.feature_store.FeatureStore: pointer to the feature store

    Raises:
        RuntimeError: if HOPSWORKS_PROJECT_NAME or HOPSWORKS_API_KEY is unset.
    """
    project_name, api_key = _get_hopsworks_credentials()
    project = hopsworks.login(
        project=project_name,
        api_key_value=api_key
    )
    return project.get_feature_store()


def get_feature_group(
    feature_group_metadata: FeatureGroupConfig
) -> hsfs.feature_group.FeatureGroup:
    """Return the feature group described by `feature_group_metadata`.

    Uses `get_or_create_feature_group`, so the group is created with the
    given description, primary key, event time and online flag if needed.

    Args:
        feature_group_metadata (FeatureGroupConfig): name, version and schema
            settings of the feature group

    Returns:
        hsfs.feature_group.FeatureGroup: pointer to the feature group
    """
    return get_feature_store().get_or_create_feature_group(
        name=feature_group_metadata.name,
        version=feature_group_metadata.version,
        description=feature_group_metadata.description,
        primary_key=feature_group_metadata.primary_key,
        event_time=feature_group_metadata.event_time,
        online_enabled=feature_group_metadata.online_enabled
    )


def get_or_create_feature_view(
    feature_view_metadata: FeatureViewConfig
) -> hsfs.feature_view.FeatureView:
    """Return the feature view described by `feature_view_metadata`, creating it if needed.

    The view selects all features of `feature_view_metadata.feature_group`,
    which must already exist.

    Args:
        feature_view_metadata (FeatureViewConfig): name/version of the view and
            of the feature group it reads from

    Returns:
        hsfs.feature_view.FeatureView: pointer to the feature view
    """
    # a single login is reused for every call below
    feature_store = get_feature_store()

    feature_group = feature_store.get_feature_group(
        name=feature_view_metadata.feature_group.name,
        version=feature_view_metadata.feature_group.version
    )

    try:
        feature_store.create_feature_view(
            name=feature_view_metadata.name,
            version=feature_view_metadata.version,
            query=feature_group.select_all()
        )
    except Exception as err:
        # NOTE(review): any creation failure is assumed to mean "already
        # exists". The error is logged, and if it was something else the
        # get_feature_view call below fails loudly.
        logger.info(f"Feature view already exists, skipping creation. ({err})")

    return feature_store.get_feature_view(
        name=feature_view_metadata.name,
        version=feature_view_metadata.version,
    )


def save_data_to_offline_feature_group(
    data: pd.DataFrame,
    feature_group_config: FeatureGroupConfig
) -> None:
    """Insert `data` into the feature group described by `feature_group_config`.

    The insert is issued with `wait_for_job=False`, so this returns without
    waiting for the Hopsworks write job to finish.
    """
    feature_group = get_feature_group(feature_group_config)
    feature_group.insert(data, write_options={"wait_for_job": False})
# --------------------------------------------------------------------------------
# /src/flow_steps.py
# --------------------------------------------------------------------------------
"""Building blocks used by src/technical_indicators_pipeline.py to assemble the dataflow.

Each function receives a bytewax `Dataflow`, appends one or more operators to
it and returns the same `Dataflow`.
"""
from typing import Tuple, Dict, List, Generator, Any, Optional, Callable
import json
import math
from datetime import datetime, timedelta, timezone
from pathlib import Path

import numpy as np
import pandas as pd
from websocket import create_connection
from bytewax.inputs import DynamicInput, StatelessSource
from bytewax.connectors.files import FileInput, CSVInput
from bytewax.dataflow import Dataflow
from bytewax.window import EventClockConfig, TumblingWindow
from bytewax.outputs import StatelessSink, DynamicOutput

from src.config import PRODUCT_IDS, WINDOW_SECONDS
from src.feature_store_api import FeatureGroupConfig, get_feature_group
from src.date_utils import epoch2datetime
from src.logger import get_console_logger

logger = get_console_logger()


def connect_to_input_csv_file(flow: Dataflow, input_file: Path) -> Dataflow:
    """Read trades from a CSV file and emit them as `(product_id, [trade])` items.

    The CSV must have `product_id`, `price`, `volume` and `timestamp` columns,
    as written by src/backfill_technical_indicators.py. Every row becomes a
    one-element list of trades, the same item shape the websocket input emits.

    Args:
        flow (Dataflow): dataflow to extend
        input_file (Path): path to the CSV file

    Returns:
        Dataflow: the same `flow`, with the CSV input and a map step added
    """
    flow.input("input", CSVInput(input_file))

    def extract_key_and_trades(data: Dict) -> Tuple[str, List[Dict]]:
        """Turn one CSV row into `(product_id, [trade])`, casting numbers to float."""
        product_id = data['product_id']
        trades = [
            {
                'price': float(data['price']),
                'volume': float(data['volume']),
                'timestamp': float(data['timestamp']),
            }
        ]
        return (product_id, trades)

    flow.map(extract_key_and_trades)

    return flow


def connect_to_kraken_websocket(flow: Dataflow) -> Dataflow:
    """Connect `flow` to Kraken's trade websocket and parse its messages into
    `(product_id, [trade, ...])` items."""
    connect_to_input_socket(flow)
    format_websocket_event(flow)

    return flow


def connect_to_input_socket(flow: Dataflow) -> Dataflow:
    """Add Kraken's websocket trade feed for PRODUCT_IDS as the input of `flow`.

    Raw messages are emitted as strings; heartbeat and subscription messages
    are filtered out.

    Args:
        flow (Dataflow): dataflow to extend

    Returns:
        Dataflow: the same `flow`, with the websocket input added
    """

    class KrakenSource(StatelessSource):
        """Websocket connection subscribed to the trades of `product_ids`."""

        def __init__(self, product_ids):

            self.product_ids = product_ids

            self.ws = create_connection("wss://ws.kraken.com/")
            self.ws.send(
                json.dumps(
                    {
                        "event": "subscribe",
                        "subscription": {"name": "trade"},
                        "pair": product_ids,
                    }
                )
            )
            # The first message received is not trade data; log and drop it.
            logger.info(f'First message from websocket: {self.ws.recv()}')

        def next(self) -> Optional[Any]:
            """Return the next trade message as a raw string, or None to skip one."""
            event = self.ws.recv()

            if 'heartbeat' in event:
                return None

            if 'channelID' in event:
                # subscription status message, not trade data
                return None

            return event

        def close(self):
            """Release the websocket when bytewax shuts this source down."""
            self.ws.close()

    class KrakenInput(DynamicInput):

        def build(self, worker_index, worker_count):
            # Deal products round-robin so every product goes to exactly one
            # worker, even when len(PRODUCT_IDS) is not a multiple of
            # worker_count or is smaller than it (an integer-division split
            # would silently drop products in those cases).
            product_ids = PRODUCT_IDS[worker_index::worker_count]
            return KrakenSource(product_ids)

    flow.input("input", KrakenInput())

    return flow


def format_websocket_event(flow: Dataflow) -> Dataflow:
    """Parse raw websocket strings into `(product_id, [trade, ...])` items.

    Expects JSON arrays where index 1 is the list of trades (each with price,
    volume and time at positions 0, 1, 2) and index 3 is the product id.
    """

    # string to json
    flow.map(json.loads)

    def extract_key_and_trades(data: List) -> Tuple[str, List[Dict]]:
        """Extract the product id and all trades of one websocket message."""
        product_id = data[3]
        trades = [
            {
                'price': float(t[0]),
                'volume': float(t[1]),
                'timestamp': float(t[2]),
            }
            for t in data[1]
        ]
        return (product_id, trades)

    flow.map(extract_key_and_trades)

    return flow


def add_tumbling_window(flow: Dataflow, window_seconds: int) -> Dataflow:
    """Group trades per product into tumbling windows of `window_seconds`.

    Each window emits `(product_id, array)` where `array` has one row
    `(timestamp, price, volume)` per trade, in arrival order (oldest first).

    Args:
        flow (Dataflow): dataflow to extend
        window_seconds (int): window length in seconds

    Returns:
        Dataflow: the same `flow`, with the windowing step added
    """

    def get_event_time(trades: List[Dict]) -> datetime:
        """Tell the event clock an item's time: the timestamp of its first trade.

        NOTE(review): all trades of one item land in the window of its first
        trade, even if later trades fall past the window boundary.
        """
        return epoch2datetime(trades[0]['timestamp'])

    def build_array() -> np.ndarray:
        """Create the empty (0, 3) accumulator of a new window."""
        return np.empty((0, 3))

    def acc_values(previous_data: np.ndarray, trades: List[Dict]) -> np.ndarray:
        """Append every trade of the item, not only the first one, as a
        `(timestamp, price, volume)` row."""
        new_rows = np.array(
            [(t['timestamp'], t['price'], t['volume']) for t in trades],
            dtype=float,
        ).reshape(-1, 3)
        return np.vstack([previous_data, new_rows])

    # Configure the `fold_window` operator to use the event time
    clock_config = EventClockConfig(
        get_event_time,
        wait_for_system_duration=timedelta(seconds=0)
    )

    window_config = TumblingWindow(
        length=timedelta(seconds=window_seconds),
        align_to=datetime(2023, 4, 3, 0, 0, 0, tzinfo=timezone.utc),
    )

    flow.fold_window(
        f"{window_seconds}_sec",
        clock_config,
        window_config,
        build_array,
        acc_values,
    )
    return flow


def aggregate_raw_trades_as_ohlc(
    flow: Dataflow,
    window_seconds: int = WINDOW_SECONDS,
) -> Dataflow:
    """Reduce the trades of each window into a single OHLC candle.

    Args:
        flow (Dataflow): dataflow to extend; its items must be the
            `(product_id, array)` pairs produced by add_tumbling_window
        window_seconds (int): window length, used to floor the candle time to
            the window start. Should match the value given to
            add_tumbling_window. Defaults to config.WINDOW_SECONDS.

    Returns:
        Dataflow: the same `flow`, with the aggregation step added
    """

    def round_timestamp(timestamp_seconds: float) -> int:
        """Floor a Unix timestamp to a multiple of `window_seconds`."""
        return int(math.floor(timestamp_seconds / window_seconds)) * window_seconds

    def calculate_features(ticker_data: Tuple[str, np.ndarray]) -> Tuple[str, Dict]:
        """Aggregate the trades of one window.

        Args:
            ticker_data (Tuple[str, np.ndarray]): product_id and the
                `(timestamp, price, volume)` rows, oldest first

        Returns:
            Tuple[str, Dict]: product_id and a dict with keys time, open,
                high, low, close, volume
        """
        ticker, data = ticker_data
        prices = data[:, 1]
        ohlc = {
            "time": round_timestamp(data[0][0]),
            "open": prices[0],
            "high": np.amax(prices),
            "low": np.amin(prices),
            "close": prices[-1],
            "volume": np.sum(data[:, 2]),
        }
        return (ticker, ohlc)

    flow.map(calculate_features)
    return flow


def tuple_to_dict(flow: Dataflow) -> Dataflow:
    """Flatten `(product_id, ohlc_dict)` items into a single dict with a
    `product_id` key and an int `time`."""

    def _tuple_to_dict(key__dict: Tuple[str, Dict]) -> Dict:
        key, record = key__dict
        record['product_id'] = key
        record['time'] = int(record['time'])
        return record

    flow.map(_tuple_to_dict)
    return flow


def save_output_to_list(flow: Dataflow, output: List[Any]) -> Dataflow:
    """Append every output item of `flow` to the in-memory list `output`."""
    from bytewax.testing import TestingOutput
    flow.output("output", TestingOutput(output))
    return flow


def save_output_to_online_feature_store(
    flow: Dataflow,
    feature_group_metadata: FeatureGroupConfig
) -> Dataflow:
    """Insert every output item of `flow`, one row at a time, into the feature
    group described by `feature_group_metadata`."""

    class HopsworksSink(StatelessSink):
        def __init__(self, feature_group):
            self.feature_group = feature_group

        def write(self, item):
            df = pd.DataFrame.from_records([item])
            logger.info(f'Saving {item} to online feature store...')
            # NOTE(review): presumably skips starting the offline
            # materialization job on every single-row insert — confirm in hsfs docs
            self.feature_group.insert(df, write_options={"start_offline_backfill": False})

    class HopsworksOutput(DynamicOutput):
        def __init__(self, feature_group_metadata):
            self.feature_group = get_feature_group(feature_group_metadata)

        def build(self, worker_index, worker_count):
            return HopsworksSink(self.feature_group)

    flow.output("hopsworks", HopsworksOutput(feature_group_metadata))
    return flow


# --------------------------------------------------------------------------------
# /src/logger.py
# --------------------------------------------------------------------------------
import logging
from typing import Optional


def get_console_logger(name: Optional[str] = 'trading_bot_gpt') -> logging.Logger:
    """Return the logger `name`, set to DEBUG and able to print to the console.

    If neither this logger nor any ancestor has a handler, a stderr
    StreamHandler is attached. Without one, Python falls back to
    `logging.lastResort`, which only prints WARNING and above, so every
    `logger.info(...)` in the project would be dropped. When a framework has
    already configured handlers (e.g. on the root logger), none is added, so
    messages are not duplicated. Calling this repeatedly adds at most one handler.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    if not logger.hasHandlers():
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )
        logger.addHandler(handler)

    return logger


# --------------------------------------------------------------------------------
# /src/old/01_basic_llm_chain.py
# --------------------------------------------------------------------------------
import os

from langchain.llms import OpenAI
from langchain import PromptTemplate, LLMChain
from dotenv import load_dotenv
from comet_llm import Span, end_chain, start_chain

load_dotenv()

OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]


def get_llm_chain() -> LLMChain:
    """Build a step-by-step question answering chain on top of OpenAI's LLM."""

    # prompt template
    template = """Question: {question}
Answer: Let's think step by step."""

    prompt = PromptTemplate(template=template, input_variables=["question"])

    # llm
    llm = OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0)

    # llm chain
    llm_chain = LLMChain(prompt=prompt, llm=llm)

    return llm_chain


llm_chain = get_llm_chain()


def main(question: str):
    """Answer `question` with the chain and log the run to Comet as a chain with one span."""

    start_chain(inputs={"question": question})

    with Span(
        category="llm-reasoning",
        inputs={"question": question},
    ) as span:
        response = llm_chain.run(question)
        print(response)
        span.set_outputs(outputs={"response": response})

    end_chain(outputs={"response": response})


if __name__ == "__main__":
    question = "What NFL team won the Super Bowl in the year Justin Beiber was born?"
    main(question)


# --------------------------------------------------------------------------------
# /src/old/02_trading_bot_fake_context.py
# --------------------------------------------------------------------------------
import os
from typing import Dict, List

from langchain.llms import OpenAI
from langchain import PromptTemplate, LLMChain
from dotenv import load_dotenv
from comet_llm import Span, end_chain, start_chain

load_dotenv()

OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]


def retrieve_technical_indicators() -> Dict[str, float]:
    """Retrieve technical indicators from a database (hard-coded placeholder)."""
    return {'RSI': 0.5, 'MACD': 0.5, 'SMA': 0.5}


def retrieve_recent_news() -> List[str]:
    """Retrieve recent news from a database (hard-coded placeholder)."""
    return [
        'ETH is about to upgrade its protocol',
        'DeFI sector is growing faster than expected.'
    ]


def get_llm_chain() -> LLMChain:
    """Build the chain that asks the LLM for an up/down signal given the context."""

    # prompt template
    template = """
    I will give you the current technical indicators for ETH/USD and some recent financial news and you wil tell me if the asset price will go up or down in the next 1 hour.

    Dont tell me that predicting short-term price movements is highly uncertain please, but be brave!

    Format the output as a python dictionary with two keys: signal, and explanation. Signal is either -1 (meaning price goes down) or 1 (meaning price goes up). Explanation is a string exposing the reasoning behind your prediction. Please provide no more text in your response apart from this python dictionary.

    # technical indicators
    {technical_indicators}

    # news
    {recent_news}
    """
    prompt = PromptTemplate(template=template,
                            input_variables=["technical_indicators", "recent_news"])

    # llm
    llm = OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0)

    # llm chain
    llm_chain = LLMChain(prompt=prompt, llm=llm)

    return llm_chain


def main(current_time: int):
    """Retrieve the (fake) context, ask the LLM for a prediction, and log every
    step to Comet as spans of one chain."""

    start_chain(inputs={"current_time": current_time})

    # retrieve current technical indicators
    with Span(
        category="context-retrieval",
        name="Retrieve technical indicators",
        inputs={"current_time": current_time},
    ) as span:
        technical_indicators = retrieve_technical_indicators()
        span.set_outputs(outputs={"technical_indicators": technical_indicators})

    # retrieve recent news
    with Span(
        category="context-retrieval",
        name="Retrieve recent news",
        inputs={"current_time": current_time},
    ) as span:
        recent_news = retrieve_recent_news()
        span.set_outputs(outputs={"recent_news": recent_news})

    # llm that will reason about the context (technical indicators + recent news)
    # and output a prediction
    llm_chain = get_llm_chain()

    with Span(
        category="llm-reasoning",
        inputs={"technical_indicators": technical_indicators, "recent_news": recent_news},
    ) as span:
        response = llm_chain.run(technical_indicators=technical_indicators, recent_news=recent_news)
        print(response)
        span.set_outputs(outputs={"response": response})

    end_chain(outputs={"response": response})


if __name__ == "__main__":
    from datetime import datetime, timezone

    # `datetime.utcnow().timestamp()` treats the naive UTC datetime as local
    # time and is off by the machine's UTC offset; an aware datetime is not.
    current_time = int(datetime.now(timezone.utc).timestamp())
    main(current_time)


# --------------------------------------------------------------------------------
# /src/technical_indicators_pipeline.py
# --------------------------------------------------------------------------------
"""Build the bytewax dataflow that turns raw trades into OHLC features."""
from pathlib import Path
from typing import Optional, List, Any

import pandas as pd
from bytewax.dataflow import Dataflow
from bytewax.connectors.stdio import StdOutput
from bytewax.testing import run_main

from src import config
from src.flow_steps import (
    connect_to_kraken_websocket,
    connect_to_input_csv_file,
    add_tumbling_window,
    aggregate_raw_trades_as_ohlc,
    tuple_to_dict,
    save_output_to_online_feature_store,
    save_output_to_list,
)
from src.logger import get_console_logger

logger = get_console_logger()

# supported values of get_dataflow's `execution_mode`
EXECUTION_MODES = ('LIVE', 'BACKFILL', 'DEBUG')


def run_dataflow_in_backfill_mode(input_file: str) -> pd.DataFrame:
    """Run the dataflow over the trades in the CSV `input_file` and return the
    resulting OHLC rows as a DataFrame (one row per window and product)."""
    logger.info('Building dataflow in BACKFILL mode')
    ohlc = []

    flow = get_dataflow(
        execution_mode='BACKFILL',
        backfill_input_file=input_file,
        backfill_output=ohlc
    )

    logger.info('Running dataflow in BACKFILL mode')
    run_main(flow)

    return pd.DataFrame(ohlc)


def get_dataflow(
    execution_mode: Optional[str] = 'LIVE',
    backfill_input_file: Optional[Path] = None,
    backfill_output: Optional[List[Any]] = None,
) -> Dataflow:
    """Construct the bytewax Dataflow that computes OHLC candles.

    Args:
        execution_mode (str): one of
            - 'LIVE': read Kraken's websocket, write to the online feature store
            - 'DEBUG': read Kraken's websocket, print results to stdout
            - 'BACKFILL': read `backfill_input_file`, append results to
              `backfill_output`
        backfill_input_file (Optional[Path]): CSV file of trades (BACKFILL only)
        backfill_output (Optional[List[Any]]): list that receives the output
            dicts (BACKFILL only)

    Returns:
        Dataflow: the assembled dataflow, ready to be run

    Raises:
        ValueError: on an unknown `execution_mode` or missing BACKFILL arguments
    """
    if execution_mode not in EXECUTION_MODES:
        raise ValueError(f'execution_mode must be one of {EXECUTION_MODES}, got {execution_mode!r}')

    window_seconds = config.WINDOW_SECONDS

    if execution_mode == 'BACKFILL':
        if not backfill_input_file:
            raise ValueError('You need to provide `backfill_input_file` when your execution_mode is BACKFILL')
        if backfill_output is None:
            raise ValueError('You need to provide `backfill_output` when your execution_mode is BACKFILL')

    flow = Dataflow()

    # connect dataflow to an input, either websocket or local file
    if execution_mode == 'BACKFILL':
        logger.info('BACKFILL MODE ON!')
        logger.info('Adding connector to local file to backfill')
        connect_to_input_csv_file(flow, backfill_input_file)
    else:
        logger.info('Adding connector to input socket')
        connect_to_kraken_websocket(flow)

    # add tumbling window
    add_tumbling_window(flow, window_seconds)

    # reduce multiple trades into a single OHLC; the same window length is used
    # so candle times are floored to the window start
    aggregate_raw_trades_as_ohlc(flow, window_seconds)

    # add technical indicators
    # TODO

    tuple_to_dict(flow)

    # send output to stdout, online feature store or in-memory list
    if execution_mode == 'DEBUG':
        logger.info('Dataflow in debug mode. Output to stdout')
        flow.output("output", StdOutput())

    elif execution_mode == 'LIVE':
        save_output_to_online_feature_store(flow, config.FEATURE_GROUP_METADATA)

    elif execution_mode == 'BACKFILL':
        logger.info('Collecting OHLC output in memory (saved to the offline feature store by the caller)')
        save_output_to_list(flow, backfill_output)

    return flow


# --------------------------------------------------------------------------------
# /tests/__init__.py
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Paulescu/trading-bot-gpt/2bf95306d6b0488de7bfdff5da784113eacf2b68/tests/__init__.py
# --------------------------------------------------------------------------------