├── .env.example
├── .gitignore
├── Makefile
├── README.md
├── images
│   ├── diagram.jpg
│   └── logo_realworldml.png
├── poetry.lock
├── pyproject.toml
├── src
│   ├── __init__.py
│   ├── backfill_technical_indicators.py
│   ├── config.py
│   ├── date_utils.py
│   ├── feature_store_api.py
│   ├── flow_steps.py
│   ├── logger.py
│   ├── old
│   │   ├── 01_basic_llm_chain.py
│   │   └── 02_trading_bot_fake_context.py
│   └── technical_indicators_pipeline.py
└── tests
    └── __init__.py
/.env.example:
--------------------------------------------------------------------------------
1 | OPENAI_API_KEY=YOUR_OPENAI_API_KEY_GOES_HERE
2 |
3 | COMET_API_KEY=YOUR_COMET_API_KEY_GOES_HERE
4 | COMET_PROJECT_NAME=YOUR_COMET_PROJECT_NAME_GOES_HERE
5 |
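6 | # Hopsworks credentials (required by src/feature_store_api.py)
7 | HOPSWORKS_PROJECT_NAME=YOUR_HOPSWORKS_PROJECT_NAME_GOES_HERE
8 | HOPSWORKS_API_KEY=YOUR_HOPSWORKS_API_KEY_GOES_HERE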
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # prefect artifacts
2 | .prefectignore
3 |
4 | # python artifacts
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 | *.egg-info/
9 | *.egg
10 |
11 | # Type checking artifacts
12 | .mypy_cache/
13 | .dmypy.json
14 | dmypy.json
15 | .pyre/
16 |
17 | # IPython
18 | profile_default/
19 | ipython_config.py
20 | *.ipynb_checkpoints/*
21 |
22 | # Environments
23 | .python-version
24 | .env
25 | .venv
26 | env/
27 | venv/
28 |
29 | # MacOS
30 | .DS_Store
31 |
32 | # Dask
33 | dask-worker-space/
34 |
35 | # Editors
36 | .idea/
37 | .vscode/
38 |
39 | # VCS
40 | .git/
41 | .hg/
42 |
43 | # Misc
44 | *.tar
45 | set_env_variables.sh
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: init debug_technical_indicators run_technical_indicators backfill_technical_indicators
2 |
3 | ### Install ###
4 |
5 | # install Poetry and Python dependencies
6 | init:
7 | curl -sSL https://install.python-poetry.org | python3 -
8 | poetry install
9 |
10 |
11 | ### Technical Indicators Pipeline ###
12 |
13 | # run the feature-pipeline locally and print out the results on the console
14 | debug_technical_indicators:
15 | 	poetry run python -m bytewax.run "src.technical_indicators_pipeline:get_dataflow(execution_mode='DEBUG')"
16 |
17 | # run the feature-pipeline and send the features to the feature store
18 | run_technical_indicators:
19 | poetry run python -m bytewax.run "src.technical_indicators_pipeline:get_dataflow()"
20 |
21 | # backfills the feature group using historical data
22 | backfill_technical_indicators:
23 | poetry run python src/backfill_technical_indicators.py --from_day $(from_day) --product_id XBT/USD
24 |
25 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Build a trading bot with OpenAI GPT-3.5, real-time data and prompt experimentation
2 | 
3 | ![architecture diagram](images/diagram.jpg)
4 | 
5 | ## TODOs
6 | 
7 | - [ ] Build and run dataflow for technical indicators
8 | - [ ] Build news summarization pipeline
9 | - [ ] Create online Feature Store
10 | - [ ] `.env` file with credentials
11 | - [ ] Copy working code
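12 | 
13 | ## Quickstart
14 | 
15 | A minimal sketch of how to run things locally, based on the Makefile targets in this repo (the `from_day` value below is only an example date):
16 | 
17 | ```bash
18 | # install Poetry and the Python dependencies
19 | make init
20 | 
21 | # fill in your OpenAI, Comet and Hopsworks credentials
22 | cp .env.example .env
23 | 
24 | # run the feature pipeline locally and print the OHLC features to the console
25 | make debug_technical_indicators
26 | 
27 | # run the feature pipeline and push the features to the feature store
28 | make run_technical_indicators
29 | 
30 | # backfill the feature group with historical data from the given day onwards
31 | make backfill_technical_indicators from_day=2023-10-01
32 | ```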
--------------------------------------------------------------------------------
/images/diagram.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Paulescu/trading-bot-gpt/2bf95306d6b0488de7bfdff5da784113eacf2b68/images/diagram.jpg
--------------------------------------------------------------------------------
/images/logo_realworldml.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Paulescu/trading-bot-gpt/2bf95306d6b0488de7bfdff5da784113eacf2b68/images/logo_realworldml.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "src"
3 | version = "0.1.0"
4 | description = ""
5 | authors = ["Pau "]
6 | readme = "README.md"
7 |
8 | [tool.poetry.dependencies]
9 | python = ">=3.10,<3.11"
10 | langchain = "^0.0.286"
11 | python-dotenv = "^1.0.0"
12 | openai = "^0.28.0"
13 | comet-llm = "^1.3.0"
14 | bytewax = "==0.16.*"
15 | pandas = "2.0.0"
16 | hsfs = "^3.4.2"
17 | hopsworks = "^3.4.2"
18 | prefect = "^2.13.7"
19 | # the packages below are imported directly in src/ (requests, pytz, websocket);
20 | # the version ranges are an assumption, adjust as needed
21 | requests = "^2.31.0"
22 | pytz = "^2023.3"
23 | websocket-client = "^1.6.0"
24 | 
25 | [build-system]
26 | requires = ["poetry-core"]
27 | build-backend = "poetry.core.masonry.api"
28 | 
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Paulescu/trading-bot-gpt/2bf95306d6b0488de7bfdff5da784113eacf2b68/src/__init__.py
--------------------------------------------------------------------------------
/src/backfill_technical_indicators.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, timedelta, date
2 | from argparse import ArgumentParser
3 | from typing import List
4 |
5 | import pytz
6 | import requests
7 | import pandas as pd
8 | from prefect import task, flow
9 | from prefect.task_runners import SequentialTaskRunner
10 |
11 | from src.feature_store_api import save_data_to_offline_feature_group
12 | from src import config
13 | from src.logger import get_console_logger
14 | # from src.dataflow import run_dataflow_in_backfill_mode
15 | from src.technical_indicators_pipeline import run_dataflow_in_backfill_mode
16 |
17 | logger = get_console_logger()
18 |
19 | PAIR = 'xbtusd'
20 | OHLC_FREQUENCY_IN_MINUTES = 1
21 |
22 | @task(retries=3, retry_delay_seconds=60)
23 | def fetch_data_from_kraken_api(
24 | product_id: str,
25 | since_nano_seconds: int
26 | ) -> List[List]:
27 |     """
28 |     Fetches trades from the Kraken API for the given `product_id` starting at `since_nano_seconds`
29 | 
30 |     Args:
31 |         product_id (str): product we fetch data for, e.g. 'XBT/USD'
32 |         since_nano_seconds (int): timestamp in nanoseconds to start fetching trades from
33 | 
34 |     Returns:
35 |         List[List]: list of trades returned by the Kraken API
36 |     """
37 | MAP_PRODUCT_ID_TO_HISTORICAL_API_PAIR_NAME = {
38 | 'XBT/USD': 'xbtusd',
39 | }
40 | pair = MAP_PRODUCT_ID_TO_HISTORICAL_API_PAIR_NAME[product_id]
41 |
42 | # build URL we need to fetch data from
43 | url = f"https://api.kraken.com/0/public/Trades?pair={pair}&since={since_nano_seconds:.0f}"
44 |
45 | # fetch data
46 | response = requests.get(url)
47 |
48 | # extract data from response
49 |     MAP_PRODUCT_ID_TO_HISTORICAL_API_RESPONSE_KEY = {
50 |         'XBT/USD': 'XXBTZUSD'
51 |     }
52 |     response_key = MAP_PRODUCT_ID_TO_HISTORICAL_API_RESPONSE_KEY[product_id]
53 | trade_params = response.json()["result"][response_key]
54 |
55 | return trade_params
56 |
57 | @flow
58 | def fetch_historical_data_one_day(day: datetime, product_id: str) -> pd.DataFrame:
59 |     """Fetches all trades for `product_id` on the given `day` from the Kraken REST API and returns them as a DataFrame"""
60 | # time range we want to fetch data for
61 | from_ts = day.timestamp()
62 | to_ts = (day + timedelta(days=1)).timestamp()
63 | # to_ts = (day + timedelta(hours=1)).timestamp()
64 |
65 | trades = []
66 | last_ts = from_ts
67 | while last_ts < to_ts:
68 |
69 | # fetch data from Kraken API
70 | trade_params = fetch_data_from_kraken_api(
71 | product_id,
72 | since_nano_seconds=last_ts*1000000000
73 | )
74 |
75 | # create a list of Dict with the results
76 | trades_in_batch = [
77 | {
78 | 'price': params[0],
79 | 'volume': params[1],
80 | 'timestamp': params[2],
81 | 'product_id': product_id,
82 | }
83 | for params in trade_params
84 | ]
85 |
86 | if len(trades_in_batch) == 0:
87 | logger.info(f'No more data for {product_id} from {datetime.utcfromtimestamp(last_ts)}')
88 | break
89 |
90 | # checking timestamps
91 | from_date = datetime.utcfromtimestamp(trades_in_batch[0]['timestamp'])
92 | to_date = datetime.utcfromtimestamp(trades_in_batch[-1]['timestamp'])
93 | logger.info(f'trades_in_batch from {from_date} to {to_date}')
94 |
95 | # add trades to list of trades
96 | trades += trades_in_batch
97 |
98 | # update last_ts for the next iteration
99 | last_ts = trades[-1]['timestamp']
100 |
101 | # TODO: remove this break. It is here to speed up the debugging process
102 | # break
103 |
104 | # drop trades that might fall outside of the time window
105 | trades = [t for t in trades if t['timestamp'] >= from_ts and t['timestamp'] <= to_ts]
106 | # logger.info(f'Fetched trade data from {trades[0]["timestamp"]} to {trades[-1]["timestamp"]}')
107 | logger.info(f'{len(trades)=}')
108 |
109 |     # convert trades to pandas dataframe
110 |     trades = pd.DataFrame(trades)
111 | 
112 |     # set correct dtypes
113 | trades['price'] = trades['price'].astype(float)
114 | trades['volume'] = trades['volume'].astype(float)
115 | trades['timestamp'] = trades['timestamp'].astype(int)
116 | trades['product_id'] = trades['product_id'].astype(str)
117 |
118 | return trades
119 |
120 | def save_data_to_csv_file(data: pd.DataFrame) -> str:
121 | """Saves data to a temporary csv file"""
122 | import time
123 | tmp_file_path = f'/tmp/{int(time.time())}.csv'
124 | data.to_csv(tmp_file_path)
125 | return tmp_file_path
126 |
127 | @task
128 | def transform_trades_to_ohlc(trades: pd.DataFrame) -> pd.DataFrame:
129 | """Transforms raw trades to OHLC data"""
130 |
131 | # convert ts column to pd.Timestamp
132 | # trades.index = pd.to_datetime(trades['ts'], unit='s')
133 |
134 | # logger.info('Saving trades to temporary file')
135 | file_path = save_data_to_csv_file(trades)
136 | logger.info(f'Saved trades to temporary file {file_path}')
137 |
138 | # breakpoint()
139 |
140 | logger.info('Creating dataflow for backfilling and running it')
141 | # from src.dataflow import run_dataflow_in_backfill_mode
142 | ohlc = run_dataflow_in_backfill_mode(input_file=file_path)
143 |
144 | # breakpoint()
145 |
146 | # remove temporary file
147 | logger.info(f'Removing temporary file {file_path}')
148 | import os
149 | os.remove(file_path)
150 |
151 | return ohlc
152 |
153 | @task(retries=3, retry_delay_seconds=60)
154 | def save_ohlc_to_feature_store(ohlc: pd.DataFrame) -> None:
155 | """Saves OHLC data to the feature store"""
156 | # save OHLC data to offline feature store
157 | logger.info('Saving OHLC data to offline feature store')
158 | save_data_to_offline_feature_group(
159 | data=ohlc,
160 | feature_group_config=config.FEATURE_GROUP_METADATA
161 | )
162 |
163 | @flow(task_runner=SequentialTaskRunner())
164 | def backfill_one_day(day: datetime, product_id: str):
165 | """Backfills OHLC data in the feature store for the given `day`"""
166 |
167 | # fetch trade data from external API
168 | trades: pd.DataFrame = fetch_historical_data_one_day(day, product_id)
169 |
170 | # transform trade data to OHLC data
171 | ohlc: pd.DataFrame = transform_trades_to_ohlc(trades)
172 |
173 | # push OHLC data to the feature store
174 | logger.info(f'Pushing OHLC data to offline feature group, {day=} {product_id=}')
175 | save_ohlc_to_feature_store(ohlc)
176 |
177 | @flow
178 | def backfill_range_dates(from_day: datetime, to_day: datetime, product_id: str):
179 | """
180 | Backfills OHLC data in the feature store for the ranges of days
181 | between `from_day` and `to_day`
182 | """
183 | days = pd.date_range(from_day, to_day, freq='D')
184 | for day in days:
185 | backfill_one_day(day, product_id=product_id)
186 |
187 |
188 | if __name__ == "__main__":
189 |
190 | parser = ArgumentParser()
191 | parser.add_argument(
192 | '--from_day',
193 | type=lambda s: datetime.strptime(s, '%Y-%m-%d').replace(tzinfo=pytz.UTC)
194 | )
195 | parser.add_argument(
196 | '--to_day',
197 | type=lambda s: datetime.strptime(s, '%Y-%m-%d').replace(tzinfo=pytz.UTC),
198 | default=datetime.utcnow().replace(tzinfo=pytz.UTC)
199 | )
200 | parser.add_argument(
201 | '--day',
202 | type=lambda s: datetime.strptime(s, '%Y-%m-%d').replace(tzinfo=pytz.UTC)
203 | )
204 |
205 | parser.add_argument(
206 | '--product_id', type=str,
207 | help="For example: XBT/USD",
208 | )
209 |
210 | args = parser.parse_args()
211 |
212 | if args.day:
213 | logger.info(f'Starting backfilling process for {args.day}')
214 | backfill_one_day(day=args.day, product_id=args.product_id)
215 |
216 | else:
217 |         assert args.from_day <= args.to_day, "from_day has to be earlier than or equal to to_day"
218 | logger.info(f'Starting backfilling process from {args.from_day} to {args.to_day}')
219 | backfill_range_dates(from_day=args.from_day, to_day=args.to_day, product_id=args.product_id)
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
1 | from src.feature_store_api import FeatureGroupConfig, FeatureViewConfig
2 |
3 | WINDOW_SECONDS = 60
4 | PRODUCT_IDS = [
5 |     "XBT/USD",  # Kraken's websocket API uses XBT (not BTC) in its pair names
6 |     # "ETH/USD",
7 | ]
8 |
9 | FEATURE_GROUP_METADATA = FeatureGroupConfig(
10 | name=f'ohlc_{WINDOW_SECONDS}_sec',
11 | version=1,
12 | description=f"OHLC data with technical indicators every {WINDOW_SECONDS} seconds",
13 | primary_key=[ 'product_id', 'time'],
14 | event_time='time',
15 | online_enabled=True,
16 | )
17 |
18 | FEATURE_VIEW_CONFIG = FeatureViewConfig(
19 | name=f'ohlc_{WINDOW_SECONDS}_sec_view',
20 | version=1,
21 | feature_group=FEATURE_GROUP_METADATA,
22 | )
--------------------------------------------------------------------------------
/src/date_utils.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, timezone
2 |
3 | def str2epoch(x: str) -> float:
4 |     return str2datetime(x).timestamp()
5 |
6 | def str2datetime(s: str) -> datetime:
7 | return datetime.strptime(s, "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc)
8 |
9 | def epoch2datetime(epoch: int) -> datetime:
10 | return datetime.utcfromtimestamp(epoch).replace(tzinfo=timezone.utc)
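11 | 
12 | # usage sketch (illustrative values):
13 | #   str2datetime("2023-04-03T00:00:00.000Z") -> datetime(2023, 4, 3, 0, 0, tzinfo=timezone.utc)
14 | #   epoch2datetime(1680480000)               -> datetime(2023, 4, 3, 0, 0, tzinfo=timezone.utc)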
--------------------------------------------------------------------------------
/src/feature_store_api.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List
3 | from dataclasses import dataclass
4 |
5 | import hsfs
6 | import hopsworks
7 | import pandas as pd
8 |
9 | from src.logger import get_console_logger
10 |
11 | logger = get_console_logger()
12 |
13 | # load Feature Store credentials from environment variables
14 | HOPSWORKS_PROJECT_NAME = os.environ['HOPSWORKS_PROJECT_NAME']
15 | HOPSWORKS_API_KEY = os.environ['HOPSWORKS_API_KEY']
16 |
17 | @dataclass
18 | class FeatureGroupConfig:
19 | name: str
20 | version: int
21 | description: str
22 | primary_key: List[str]
23 | event_time: str
24 | online_enabled: bool
25 |
26 | @dataclass
27 | class FeatureViewConfig:
28 | name: str
29 | version: int
30 | feature_group: FeatureGroupConfig
31 |
32 |
33 | def get_feature_store() -> hsfs.feature_store.FeatureStore:
34 | """Connects to Hopsworks and returns a pointer to the feature store
35 |
36 | Returns:
37 | hsfs.feature_store.FeatureStore: pointer to the feature store
38 | """
39 | project = hopsworks.login(
40 | project=HOPSWORKS_PROJECT_NAME,
41 | api_key_value=HOPSWORKS_API_KEY
42 | )
43 | return project.get_feature_store()
44 |
45 |
46 | def get_feature_group(
47 | feature_group_metadata: FeatureGroupConfig
48 | ) -> hsfs.feature_group.FeatureGroup:
49 |     """Connects to the feature store and returns a pointer to the feature group
50 |     described by `feature_group_metadata`
51 | 
52 |     Args:
53 |         feature_group_metadata (FeatureGroupConfig): name, version, primary key,
54 |             event time and online flag of the feature group to get or create
55 | 
56 |     Returns:
57 |         hsfs.feature_group.FeatureGroup: pointer to the feature group
58 |     """
59 | return get_feature_store().get_or_create_feature_group(
60 | name=feature_group_metadata.name,
61 | version=feature_group_metadata.version,
62 | description=feature_group_metadata.description,
63 | primary_key=feature_group_metadata.primary_key,
64 | event_time=feature_group_metadata.event_time,
65 | online_enabled=feature_group_metadata.online_enabled
66 | )
67 |
68 |
69 | def get_or_create_feature_view(
70 | feature_view_metadata: FeatureViewConfig
71 | ) -> hsfs.feature_view.FeatureView:
72 |     """Creates the feature view if it does not exist and returns a pointer to it"""
73 |
74 | # get pointer to the feature store
75 | feature_store = get_feature_store()
76 |
77 | # get pointer to the feature group
78 | # from src.config import FEATURE_GROUP_METADATA
79 | feature_group = feature_store.get_feature_group(
80 | name=feature_view_metadata.feature_group.name,
81 | version=feature_view_metadata.feature_group.version
82 | )
83 |
84 | # create feature view if it doesn't exist
85 | try:
86 | feature_store.create_feature_view(
87 | name=feature_view_metadata.name,
88 | version=feature_view_metadata.version,
89 | query=feature_group.select_all()
90 | )
91 |     except Exception:
92 |         # creation raises if the feature view already exists
93 |         logger.info("Feature view already exists, skipping creation.")
94 |
95 | # get feature view
96 | feature_store = get_feature_store()
97 | feature_view = feature_store.get_feature_view(
98 | name=feature_view_metadata.name,
99 | version=feature_view_metadata.version,
100 | )
101 |
102 | return feature_view
103 |
104 | # from prefect import task
105 | # @task(retries=3, retry_delay_seconds=60)
106 | def save_data_to_offline_feature_group(
107 | data: pd.DataFrame,
108 | feature_group_config: FeatureGroupConfig
109 | ) -> None:
110 |     """Inserts `data` into the feature group described by `feature_group_config`"""
111 | feature_group = get_feature_group(feature_group_config)
112 | feature_group.insert(data, write_options={"wait_for_job": False})
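113 | 
114 | # usage sketch: the feature group / view configs live in src.config; reading the features
115 | # back would go through the hsfs feature-view API, e.g. (illustrative, not called in this module):
116 | #   from src.config import FEATURE_VIEW_CONFIG
117 | #   feature_view = get_or_create_feature_view(FEATURE_VIEW_CONFIG)
118 | #   features_df = feature_view.get_batch_data()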
--------------------------------------------------------------------------------
/src/flow_steps.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Dict, List, Generator, Any, Optional, Callable
2 | import json
3 | from datetime import datetime, timedelta, timezone
4 | from pathlib import Path
5 |
6 | import numpy as np
7 | import pandas as pd
8 | from websocket import create_connection
9 | from bytewax.inputs import DynamicInput, StatelessSource
10 | from bytewax.connectors.files import FileInput, CSVInput
11 | from bytewax.dataflow import Dataflow
12 | from bytewax.window import EventClockConfig, TumblingWindow
13 | from bytewax.outputs import StatelessSink, DynamicOutput
14 |
15 | from src.config import PRODUCT_IDS, WINDOW_SECONDS
16 | from src.feature_store_api import FeatureGroupConfig, get_feature_group
17 | from src.date_utils import epoch2datetime
18 | from src.logger import get_console_logger
19 |
20 | logger = get_console_logger()
21 |
22 | def connect_to_input_csv_file(flow: Dataflow, input_file: Path) -> Dataflow:
23 |
24 | # connect to input file
25 | flow.input("input", CSVInput(input_file))
26 |
27 | # extract product_id as key and trades as list of dicts
28 | def extract_key_and_trades(data: Dict) -> Tuple[str, List[Dict]]:
29 |         """Turns a CSV row into a (product_id, [trade]) tuple"""
30 | product_id = data['product_id']
31 | trades = [
32 | {
33 | 'price': float(data['price']),
34 | 'volume': float(data['volume']),
35 | 'timestamp': float(data['timestamp']),
36 | }
37 | ]
38 | return (product_id, trades)
39 |
40 | flow.map(extract_key_and_trades)
41 |
42 | return flow
43 |
44 | def connect_to_kraken_websocket(flow: Dataflow) -> Dataflow:
45 |     """Adds the Kraken websocket as the dataflow input and formats its events"""
46 | connect_to_input_socket(flow)
47 | format_websocket_event(flow)
48 |
49 | return flow
50 |
51 | def connect_to_input_socket(flow: Dataflow) -> Dataflow:
52 |     """Connects the given dataflow to the Kraken websocket
53 | 
54 |     Args:
55 |         flow (Dataflow): dataflow to attach the websocket input to
56 |     """
57 | class KrakenSource(StatelessSource):
58 | def __init__(self, product_ids):
59 |
60 | self.product_ids = product_ids
61 |
62 | self.ws = create_connection("wss://ws.kraken.com/")
63 | self.ws.send(
64 | json.dumps(
65 | {
66 | "event": "subscribe",
67 | "subscription": {"name":"trade"},
68 | "pair": product_ids,
69 | }
70 | )
71 | )
72 | # The first msg is just a confirmation that we have subscribed.
73 | logger.info(f'First message from websocket: {self.ws.recv()}')
74 |
75 | def next(self) -> Optional[Any]:
76 |
77 | event = self.ws.recv()
78 |
79 | if 'heartbeat' in event:
80 | # logger.info(f'Heartbeat event: {event}')
81 | return None
82 |
83 | if 'channelID' in event:
84 | # logger.info(f'Subscription event: {event}')
85 | return None
86 |
87 | # logger.info(f'{event=}')
88 | return event
89 |
90 | class KrakenInput(DynamicInput):
91 |
92 | def build(self, worker_index, worker_count):
93 | prods_per_worker = int(len(PRODUCT_IDS) / worker_count)
94 | product_ids = PRODUCT_IDS[
95 | int(worker_index * prods_per_worker) : int(
96 | worker_index * prods_per_worker + prods_per_worker
97 | )
98 | ]
99 | return KrakenSource(product_ids)
100 |
101 | flow.input("input", KrakenInput())
102 |
103 | return flow
104 |
105 | def format_websocket_event(flow: Dataflow) -> Dataflow:
106 |
107 | # string to json
108 | flow.map(json.loads)
109 |
110 | # extract product_id as key and trades as list of dicts
111 | def extract_key_and_trades(data: Dict) -> Tuple[str, List[Dict]]:
112 |         """Turns a Kraken trade event into a (product_id, trades) tuple"""
113 | product_id = data[3]
114 | trades = [
115 | {
116 | 'price': float(t[0]),
117 | 'volume': float(t[1]),
118 | 'timestamp': float(t[2]),
119 | }
120 | for t in data[1]
121 | ]
122 | return (product_id, trades)
123 |
124 | flow.map(extract_key_and_trades)
125 |
126 | return flow
127 |
128 | def add_tumbling_window(flow: Dataflow, window_seconds: int) -> Dataflow:
129 |
130 | def get_event_time(trades: List[Dict]) -> datetime:
131 | """
132 | This function instructs the event clock on how to retrieve the
133 | event's datetime from the input.
134 | """
135 | # use timestamp from the first trade in the list
136 | return epoch2datetime(trades[0]['timestamp'])
137 |
138 |     def build_array() -> np.array:
139 |         """Builds the initial accumulator for a window
140 | 
141 |         Returns:
142 |             np.array: empty (0, 3) array that accumulates (timestamp, price, volume) rows
143 |         """
144 |         return np.empty((0, 3))
145 |
146 | def acc_values(previous_data: np.array, trades: List[Dict]) -> np.array:
147 | """
148 | This is the accumulator function, and outputs a numpy array of time and price
149 | """
150 | # TODO: fix this to add all trades, not just the first in the event
151 | return np.insert(previous_data, 0,
152 | np.array((trades[0]['timestamp'], trades[0]['price'], trades[0]['volume'])), 0)
153 |
154 |
155 | # Configure the `fold_window` operator to use the event time
156 | clock_config = EventClockConfig(
157 | get_event_time,
158 | wait_for_system_duration=timedelta(seconds=0)
159 | )
160 |
161 | window_config = TumblingWindow(length=timedelta(seconds=window_seconds),
162 | align_to=datetime(2023, 4, 3, 0, 0, 0, tzinfo=timezone.utc))
163 |
164 | flow.fold_window(f"{window_seconds}_sec",
165 | clock_config,
166 | window_config,
167 | build_array,
168 | acc_values)
169 | return flow
170 |
171 | def aggregate_raw_trades_as_ohlc(flow: Dataflow) -> Dataflow:
172 |
173 | # compute OHLC for the window
174 | def calculate_features(ticker_data: Tuple[str, np.array]) -> Tuple[str, Dict]:
175 | """Aggregate trade data in window
176 |
177 | Args:
178 |             ticker_data (Tuple[str, np.array]): product_id and array of (timestamp, price, volume) rows
179 |
180 | Returns:
181 | Tuple[str, Dict]: product_id, Dict with keys
182 | - time
183 | - open
184 | - high
185 | - low
186 | - close
187 | - volume
188 | """
189 | def round_timestamp(timestamp_seconds: int) -> int:
190 | import math
191 | return int(math.floor(timestamp_seconds / WINDOW_SECONDS * 1.0)) * WINDOW_SECONDS
192 |
193 |         ticker, data = ticker_data
194 |         # acc_values prepends each new trade, so data[0] is the newest trade and data[-1] the oldest
195 |         ohlc = {
196 |             "time": round_timestamp(data[-1][0]),
197 | "open": data[:,1][-1],
198 | "high": np.amax(data[:,1]),
199 | "low": np.amin(data[:,1]),
200 | "close": data[:,1][0],
201 | "volume": np.sum(data[:,2])
202 | }
203 | return (ticker, ohlc)
204 |
205 | flow.map(calculate_features)
206 | return flow
207 |
208 | def tuple_to_dict(flow: Dataflow) -> Dataflow:
209 |
210 |     def _tuple_to_dict(key__dict: Tuple[str, Dict]) -> Dict:
211 |         key, d = key__dict
212 |         d['product_id'] = key
213 | 
214 |         # TODO: fix this upstream
215 |         d['time'] = int(d['time'])
216 | 
217 |         return d
218 |
219 | flow.map(_tuple_to_dict)
220 | return flow
221 |
222 | def save_output_to_list(flow: Dataflow, output: List[Any]) -> Dataflow:
223 | """Saves output to a list"""
224 | from bytewax.testing import TestingOutput
225 | flow.output("output", TestingOutput(output))
226 | return flow
227 |
228 | def save_output_to_online_feature_store(
229 | flow: Dataflow,
230 | feature_group_metadata: FeatureGroupConfig
231 | ) -> Dataflow:
232 | class HopsworksSink(StatelessSink):
233 | def __init__(self, feature_group):
234 | self.feature_group = feature_group
235 |
236 | def write(self, item):
237 | df = pd.DataFrame.from_records([item])
238 | logger.info(f'Saving {item} to online feature store...')
239 | self.feature_group.insert(df, write_options={"start_offline_backfill": False})
240 |
241 | class HopsworksOutput(DynamicOutput):
242 | def __init__(self, feature_group_metadata):
243 | self.feature_group = get_feature_group(feature_group_metadata)
244 |
245 | def build(self, worker_index, worker_count):
246 | return HopsworksSink(self.feature_group)
247 |
248 | flow.output("hopsworks", HopsworksOutput(feature_group_metadata))
249 | return flow
--------------------------------------------------------------------------------
/src/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Optional
3 |
4 | def get_console_logger(name: Optional[str] = 'trading_bot_gpt') -> logging.Logger:
5 | 
6 |     # create (or fetch) the logger
7 |     logger = logging.getLogger(name)
8 |     logger.setLevel(logging.DEBUG)
9 | 
10 |     # attach a console handler once, so INFO/DEBUG messages actually reach the console
11 |     if not logger.handlers:
12 |         handler = logging.StreamHandler()
13 |         handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
14 |         logger.addHandler(handler)
15 | 
16 |     return logger
--------------------------------------------------------------------------------
/src/old/01_basic_llm_chain.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from langchain.llms import OpenAI
4 | from langchain import PromptTemplate, LLMChain
5 | from dotenv import load_dotenv
6 | from comet_llm import Span, end_chain, start_chain
7 |
8 | load_dotenv()
9 |
10 | OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
11 |
12 | def get_llm_chain() -> LLMChain:
13 |
14 |     # prompt template
15 | template = """Question: {question}
16 | Answer: Let's think step by step."""
17 |
18 | prompt = PromptTemplate(template=template, input_variables=["question"])
19 |
20 | # llm
21 | llm = OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0)
22 |
23 | # llm chain
24 | llm_chain = LLMChain(prompt=prompt, llm=llm)
25 |
26 | return llm_chain
27 |
28 | llm_chain = get_llm_chain()
29 |
30 | def main(question: str):
31 |
32 | start_chain(inputs={"question": question})
33 |
34 | with Span(
35 | category="llm-reasoning",
36 | inputs={"question": question},
37 | ) as span:
38 | response = llm_chain.run(question)
39 | print(response)
40 | span.set_outputs(outputs={"response": response})
41 |
42 | end_chain(outputs={"response": response})
43 |
44 | question = "What NFL team won the Super Bowl in the year Justin Bieber was born?"
45 | main(question)
--------------------------------------------------------------------------------
/src/old/02_trading_bot_fake_context.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, List
3 |
4 | from langchain.llms import OpenAI
5 | from langchain import PromptTemplate, LLMChain
6 | from dotenv import load_dotenv
7 | from comet_llm import Span, end_chain, start_chain
8 |
9 | load_dotenv()
10 |
11 | OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
12 |
13 | def retrieve_technical_indicators() -> Dict[str, float]:
14 | """Retrieve technical indicators from a database."""
15 | return {'RSI': 0.5, 'MACD': 0.5, 'SMA': 0.5}
16 |
17 |
18 | def retrieve_recent_news() -> List[str]:
19 | """Retrieve recent news from a database."""
20 | return [
21 | 'ETH is about to upgrade its protocol',
22 | 'DeFI sector is growing faster than expected.'
23 | ]
24 |
25 | def get_llm_chain() -> LLMChain:
26 |
27 |     # prompt template
28 | template = """
29 |     I will give you the current technical indicators for ETH/USD and some recent financial news and you will tell me if the asset price will go up or down in the next 1 hour.
30 |
31 |     Don't tell me that predicting short-term price movements is highly uncertain please, but be brave!
32 |
33 | Format the output as a python dictionary with two keys: signal, and explanation. Signal is either -1 (meaning price goes down) or 1 (meaning price goes up). Explanation is a string exposing the reasoning behind your prediction. Please provide no more text in your response apart from this python dictionary.
34 |
35 | # technical indicators
36 | {technical_indicators}
37 |
38 | # news
39 | {recent_news}
40 | """
41 | prompt = PromptTemplate(template=template,
42 | input_variables=["technical_indicators", "recent_news"])
43 |
44 | # llm
45 | llm = OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0)
46 |
47 | # llm chain
48 | llm_chain = LLMChain(prompt=prompt, llm=llm)
49 |
50 | return llm_chain
51 |
52 | def main(current_time: int):
53 |
54 | start_chain(inputs={"current_time": current_time})
55 |
56 | # retrieve current technical indicators
57 | with Span(
58 | category="context-retrieval",
59 | name="Retrieve technical indicators",
60 | inputs={"current_time": current_time},
61 | ) as span:
62 | technical_indicators = retrieve_technical_indicators()
63 | span.set_outputs(outputs={"technical_indicators": technical_indicators})
64 |
65 | # retrieve recent news
66 | with Span(
67 | category="context-retrieval",
68 | name="Retrieve recent news",
69 | inputs={"current_time": current_time},
70 | ) as span:
71 | recent_news = retrieve_recent_news()
72 | span.set_outputs(outputs={"recent_news": recent_news})
73 |
74 | # llm that will reason about the context (technical indicators + recent news)
75 | # and output a prediction
76 | llm_chain = get_llm_chain()
77 |
78 | with Span(
79 | category="llm-reasoning",
80 | inputs={"technical_indicators": technical_indicators,"recent_news": recent_news},
81 | ) as span:
82 | response = llm_chain.run(technical_indicators=technical_indicators, recent_news=recent_news)
83 | print(response)
84 | span.set_outputs(outputs={"response": response})
85 |
86 | end_chain(outputs={"response": response})
87 |
88 | from datetime import datetime
89 |
90 | current_time = int(datetime.utcnow().timestamp())
91 | main(current_time)
--------------------------------------------------------------------------------
/src/technical_indicators_pipeline.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Optional, List, Any
3 |
4 | import pandas as pd
5 | from bytewax.dataflow import Dataflow
6 | from bytewax.connectors.stdio import StdOutput
7 | from bytewax.testing import run_main
8 |
9 | from src import config
10 | from src.flow_steps import (
11 | connect_to_kraken_websocket,
12 | connect_to_input_csv_file,
13 | add_tumbling_window,
14 | aggregate_raw_trades_as_ohlc,
15 | tuple_to_dict,
16 | save_output_to_online_feature_store,
17 | save_output_to_list,
18 | )
19 | from src.logger import get_console_logger
20 |
21 | logger = get_console_logger()
22 |
23 |
24 | def run_dataflow_in_backfill_mode(input_file: str) -> pd.DataFrame:
25 |     """Runs the dataflow against the trades in `input_file` and returns the resulting OHLC rows as a DataFrame"""
26 | logger.info('Building dataflow in BACKFILL mode')
27 | ohlc = []
28 |
29 | flow = get_dataflow(
30 | execution_mode='BACKFILL',
31 | backfill_input_file=input_file,
32 | backfill_output=ohlc
33 | )
34 |
35 | logger.info('Running dataflow in BACKFILL mode')
36 | run_main(flow)
37 |
38 | return pd.DataFrame(ohlc)
39 |
40 |
41 | def get_dataflow(
42 | execution_mode: Optional[str] = 'LIVE',
43 | backfill_input_file: Optional[Path] = None,
44 |     backfill_output: Optional[List[Any]] = None,
45 | ) -> Dataflow:
46 |     """Constructs and returns a Bytewax Dataflow
47 | 
48 |     Args:
49 |         execution_mode (str): 'LIVE', 'BACKFILL' or 'DEBUG'
50 |         backfill_input_file / backfill_output: raw-trades CSV and output list, only used in BACKFILL mode
51 |     Returns:
52 |         Dataflow: the configured dataflow
53 |     """
54 | window_seconds = config.WINDOW_SECONDS
55 | assert execution_mode in ['LIVE', 'BACKFILL', 'DEBUG']
56 |
57 | flow = Dataflow()
58 |
59 | # connect dataflow to an input, either websocket or local file
60 | if execution_mode == 'BACKFILL':
61 | logger.info('BACKFILL MODE ON!')
62 | logger.info('Adding connector to local file to backfill')
63 |
64 | if not backfill_input_file:
65 | raise Exception('You need to provide `backfill_input_file` when your execution_mode is BACKFILL')
66 |
67 | connect_to_input_csv_file(flow, backfill_input_file)
68 | else:
69 | logger.info('Adding connector to input socket')
70 | connect_to_kraken_websocket(flow)
71 |
72 | # add tumbling window
73 | add_tumbling_window(flow, window_seconds)
74 |
75 | # reduce multiple trades into a single OHLC
76 | aggregate_raw_trades_as_ohlc(flow)
77 |
78 | # add technical indicators
79 | # TODO
80 |
81 | tuple_to_dict(flow)
82 |
83 | # send output to stdout, online feature store or in-memory list
84 | if execution_mode == 'DEBUG':
85 | logger.info('Dataflow in debug mode. Output to stdout')
86 | flow.output("output", StdOutput())
87 |
88 | elif execution_mode == 'LIVE':
89 | from src.config import FEATURE_GROUP_METADATA
90 | save_output_to_online_feature_store(flow, FEATURE_GROUP_METADATA)
91 |
92 | elif execution_mode == 'BACKFILL':
93 |         logger.info('Collecting OHLC rows in an in-memory list for the backfill job')
94 | save_output_to_list(flow, backfill_output)
95 |
96 | return flow
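97 | 
98 | # usage sketch: this module is meant to be executed with the Bytewax runner (see the Makefile targets), e.g.
99 | #   poetry run python -m bytewax.run "src.technical_indicators_pipeline:get_dataflow()"
100 | #   poetry run python -m bytewax.run "src.technical_indicators_pipeline:get_dataflow(execution_mode='DEBUG')"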
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Paulescu/trading-bot-gpt/2bf95306d6b0488de7bfdff5da784113eacf2b68/tests/__init__.py
--------------------------------------------------------------------------------