├── img
│   └── logo.png
├── prediction_chart.png
├── model
│   ├── __pycache__
│   │   ├── kronos.cpython-39.pyc
│   │   ├── module.cpython-39.pyc
│   │   └── __init__.cpython-39.pyc
│   ├── __init__.py
│   ├── kronos.py
│   └── module.py
├── index.html
├── style.css
└── update_predictions.py
/img/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shiyu-coder/Kronos-demo/HEAD/img/logo.png
--------------------------------------------------------------------------------
/prediction_chart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shiyu-coder/Kronos-demo/HEAD/prediction_chart.png
--------------------------------------------------------------------------------
/model/__pycache__/kronos.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shiyu-coder/Kronos-demo/HEAD/model/__pycache__/kronos.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/module.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shiyu-coder/Kronos-demo/HEAD/model/__pycache__/module.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shiyu-coder/Kronos-demo/HEAD/model/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .kronos import KronosTokenizer, Kronos, KronosPredictor
2 |
3 | model_dict = {
4 | 'kronos_tokenizer': KronosTokenizer,
5 | 'kronos': Kronos,
6 | 'kronos_predictor': KronosPredictor
7 | }
8 |
9 |
10 | def get_model_class(model_name):
11 | if model_name in model_dict:
12 | return model_dict[model_name]
13 | else:
14 | print(f"Model {model_name} not found in model_dict")
15 | raise NotImplementedError
16 |
17 |
18 |
--------------------------------------------------------------------------------
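
A minimal usage sketch of the registry above (a hypothetical snippet; assumes it runs from the repository root so the `model` package is importable):

    from model import get_model_class, Kronos

    assert get_model_class('kronos') is Kronos   # registry lookup by name
    # get_model_class('nonexistent') raises NotImplementedError.
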
/index.html:
--------------------------------------------------------------------------------
1 | <!DOCTYPE html>
2 | <html lang="en">
3 | <head>
4 | <meta charset="UTF-8">
5 | <meta name="viewport" content="width=device-width, initial-scale=1.0">
6 | <title>Kronos Live Forecast | BTC/USDT</title>
7 | <link rel="stylesheet" href="style.css">
8 | </head>
9 | <body>
10 | <header>
11 | <div class="container">
12 | <img src="img/logo.png" alt="Kronos logo">
13 | <h1 class="project-title">Live BTC/USDT Forecast Dashboard</h1>
14 | <a class="repo-link" href="https://github.com/shiyu-coder/Kronos">View Kronos on GitHub</a>
15 | </div>
16 | </header>
17 | <main class="container">
18 | <section id="live-dashboard">
19 | <div class="metadata">
20 | <span>Last Updated (UTC): <strong id="last-updated">2025-12-14 19:00:43</strong></span>
21 | <span>Data Source: Binance | Interval: 1-Hour</span>
22 | </div>
23 | <div class="metrics-container">
24 | <div class="metric-card">
25 | <h3>Upside Probability (Next 24h)</h3>
26 | <p class="metric-value" id="upside-prob">50.0%</p>
27 | <p class="metric-desc">The model's confidence that the price in 24 hours will be higher than the last known price.</p>
28 | </div>
29 | <div class="metric-card">
30 | <h3>Volatility Amplification (Next 24h)</h3>
31 | <p class="metric-value" id="vol-amp-prob">90.0%</p>
32 | <p class="metric-desc">The probability that predicted volatility over the next 24h will exceed recent historical volatility.</p>
33 | </div>
34 | </div>
35 | <div class="chart-wrapper">
36 | <h3>24-Hour Probabilistic Forecast</h3>
37 | <p>The chart below shows the historical price (blue) and the probabilistic forecast (orange). The orange line is the mean of multiple Monte Carlo simulations, and the shaded area represents the full range of predicted outcomes, indicating forecast uncertainty.</p>
38 | <div class="chart-container">
39 | <img class="chart-img" src="prediction_chart.png" alt="BTC/USDT 24-hour probabilistic forecast chart">
40 | </div>
41 | </div>
42 | </section>
43 | <section id="methodology">
44 | <h2>Methodology Overview</h2>
45 | <p>This demo showcases the forecasting results of <strong>Kronos</strong>, a foundation model pre-trained on the "language" of financial markets. The predictions are generated using the following process:</p>
46 | <ul>
47 | <li><strong>Model:</strong> The `Kronos-mini` (4M parameters) model is used to autoregressively predict future K-line data.</li>
48 | <li><strong>Data Context:</strong> The model uses the last 360 hours (~15 days) of BTC/USDT 1h K-line data from Binance as context for each new prediction.</li>
49 | <li><strong>Probabilistic Forecasting:</strong> We employ Monte Carlo sampling (N=30 paths) to generate a distribution of possible future price trajectories, not just a single point forecast.</li>
50 | <li><strong>Derived Insights:</strong> The resulting distribution is analyzed to produce the mean forecast (solid line), the uncertainty range (shaded area), and the key probability metrics shown above.</li>
51 | </ul>
52 | </section>
53 | <section id="about">
54 | <h2>About The Kronos Project</h2>
55 | <p>Kronos is the first open-source foundation model for financial candlesticks (K-lines), trained on data from over 45 global exchanges. It is designed to serve as a unified model for diverse quantitative finance tasks.</p>
56 | <div class="project-links">
57 | <a class="btn btn-primary" href="https://github.com/shiyu-coder/Kronos">GitHub Repository</a>
58 | </div>
59 | </section>
60 | </main>
61 | <footer>
62 | <p>Kronos Live Forecast, updated hourly by update_predictions.py</p>
63 | </footer>
64 | </body>
65 | </html>
--------------------------------------------------------------------------------
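
To make the two dashboard metrics concrete, here is a toy, self-contained sketch of how they fall out of a set of sampled price paths (synthetic numbers and a 2-hour horizon for brevity; the production logic is `calculate_metrics` in update_predictions.py below):

    import numpy as np

    last_close = 100.0
    paths = np.array([[100.5, 99.0, 101.2],   # hour 1; one column per sampled path
                      [101.0, 98.5, 102.0]])  # hour 2 (the demo uses 24 hours, 30 paths)

    # Upside probability: fraction of paths whose final price beats the last close.
    upside_prob = (paths[-1] > last_close).mean()   # 2/3 here

    # Volatility amplification: fraction of paths whose log-return volatility
    # exceeds recent historical volatility (the 0.01 threshold is a placeholder).
    full = np.vstack([np.full((1, 3), last_close), paths])
    vol_amp_prob = (np.diff(np.log(full), axis=0).std(axis=0) > 0.01).mean()
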
/style.css:
--------------------------------------------------------------------------------
1 | /* --- Global Styles & Variables --- */
2 | :root {
3 | --primary-color: #3498db;
4 | --secondary-color: #2c3e50;
5 | --light-color: #ecf0f1;
6 | --dark-color: #34495e;
7 | --background-color: #f4f7f6;
8 | --card-bg-color: #ffffff;
9 | --text-color: #333;
10 | --subtle-text-color: #6c757d;
11 | --border-color: #e9ecef;
12 | }
13 |
14 | body {
15 | font-family: 'Inter', 'Lato', sans-serif;
16 | line-height: 1.7;
17 | margin: 0;
18 | padding: 0;
19 | background-color: var(--background-color);
20 | color: var(--text-color);
21 | }
22 |
23 | .container {
24 | max-width: 960px;
25 | margin: auto;
26 | padding: 0 1.5rem;
27 | }
28 |
29 | h1, h2, h3 {
30 | font-family: 'Roboto Slab', serif;
31 | color: var(--dark-color);
32 | line-height: 1.3;
33 | }
34 |
35 | h1 { font-size: 2.8rem; }
36 | h2 { font-size: 2.2rem; border-bottom: 3px solid var(--border-color); padding-bottom: 0.6rem; margin-bottom: 1.5rem; }
37 | h3 { font-size: 1.4rem; color: var(--secondary-color); }
38 |
39 | p {
40 | margin-bottom: 1rem;
41 | }
42 |
43 | /* --- Header --- */
44 | header {
45 | background: var(--secondary-color);
46 | color: var(--light-color);
47 | padding: 2.5rem 0;
48 | text-align: center;
49 | border-bottom: 5px solid var(--primary-color);
50 | }
51 |
52 | header .project-title {
53 | margin: 0 0 0.5rem 0;
54 | color: var(--light-color);
55 | font-size: 2.8rem;
56 | }
57 |
58 | header .subtitle {
59 | font-size: 1.25rem;
60 | color: #bdc3c7;
61 | margin-bottom: 2rem;
62 | font-style: italic;
63 | }
64 |
65 | /* UPDATED: Style for the GitHub link to ensure high visibility in all states */
66 | header .repo-link,
67 | header .repo-link:link,
68 | header .repo-link:visited {
69 | color: var(--light-color); /* Ensures normal and visited links are light-colored */
70 | background-color: transparent;
71 | border: 2px solid var(--primary-color);
72 | padding: 0.6rem 1.5rem;
73 | border-radius: 50px; /* Pill shape */
74 | text-decoration: none;
75 | font-weight: 600;
76 | transition: background-color 0.3s, color 0.3s;
77 | }
78 |
79 | header .repo-link:hover,
80 | header .repo-link:active {
81 | background-color: var(--primary-color);
82 | color: var(--card-bg-color);
83 | }
84 |
85 |
86 | /* --- Main Content & Sections --- */
87 | main {
88 | margin: 2.5rem auto;
89 | }
90 |
91 | section {
92 | background: var(--card-bg-color);
93 | margin-bottom: 2.5rem;
94 | padding: 2rem;
95 | border-radius: 12px;
96 | box-shadow: 0 4px 15px rgba(0,0,0,0.07);
97 | }
98 |
99 | /* --- Dashboard Section --- */
100 | #live-dashboard .metadata {
101 | display: flex;
102 | justify-content: space-between;
103 | flex-wrap: wrap;
104 | font-style: italic;
105 | color: var(--subtle-text-color);
106 | margin-bottom: 2rem;
107 | padding-bottom: 1rem;
108 | border-bottom: 1px solid var(--border-color);
109 | }
110 |
111 | .metrics-container {
112 | display: flex;
113 | justify-content: space-around;
114 | gap: 1.5rem;
115 | text-align: center;
116 | margin-bottom: 2rem;
117 | }
118 |
119 | .metric-card {
120 | flex: 1;
121 | min-width: 220px;
122 | padding: 1.5rem;
123 | background-color: #f8f9fa;
124 | border-radius: 8px;
125 | border: 1px solid var(--border-color);
126 | transition: transform 0.2s, box-shadow 0.2s;
127 | }
128 |
129 | .metric-card:hover {
130 | transform: translateY(-5px);
131 | box-shadow: 0 6px 20px rgba(0,0,0,0.1);
132 | }
133 |
134 | .metric-card h3 {
135 | margin-top: 0;
136 | font-size: 1.1rem;
137 | }
138 |
139 | .metric-value {
140 | font-size: 3rem;
141 | font-weight: 700;
142 | color: var(--primary-color);
143 | margin: 0.5rem 0;
144 | }
145 |
146 | .metric-desc {
147 | font-size: 0.9rem;
148 | color: var(--subtle-text-color);
149 | }
150 |
151 | .chart-wrapper p {
152 | max-width: 80ch;
153 | }
154 |
155 | .chart-container {
156 | width: 100%;
157 | text-align: center;
158 | margin-top: 1rem;
159 | }
160 |
161 | .chart-img {
162 | max-width: 100%;
163 | height: auto;
164 | border: 1px solid var(--border-color);
165 | border-radius: 8px;
166 | }
167 |
168 | /* --- Methodology & About Sections --- */
169 | #methodology ul {
170 | list-style-type: disc;
171 | padding-left: 20px;
172 | }
173 |
174 | #methodology li {
175 | margin-bottom: 0.75rem;
176 | }
177 |
178 | .project-links {
179 | margin-top: 1.5rem;
180 | display: flex;
181 | gap: 1rem;
182 | }
183 |
184 | .btn {
185 | text-decoration: none;
186 | padding: 0.8rem 1.8rem;
187 | border-radius: 5px;
188 | font-weight: 600;
189 | transition: all 0.3s ease;
190 | border: none;
191 | }
192 | .btn-primary {
193 | background-color: var(--primary-color);
194 | color: white;
195 | }
196 | .btn-primary:hover {
197 | background-color: #2980b9;
198 | }
199 | .btn-secondary {
200 | background-color: var(--border-color);
201 | color: var(--dark-color);
202 | }
203 | .btn-secondary:hover {
204 | background-color: #dcdfe2;
205 | }
206 | .btn.disabled {
207 | background-color: #e0e0e0;
208 | color: #a0a0a0;
209 | cursor: not-allowed;
210 | pointer-events: none;
211 | }
212 |
213 |
214 | /* --- Footer --- */
215 | footer {
216 | text-align: center;
217 | padding: 2rem 0;
218 | background: var(--dark-color);
219 | color: var(--light-color);
220 | font-size: 0.9rem;
221 | }
222 |
223 | /* --- Responsive Design --- */
224 | @media (max-width: 768px) {
225 | header .project-title { font-size: 2.2rem; }
226 | h2 { font-size: 1.8rem; }
227 | .container { padding: 0 1rem; }
228 |
229 | .metrics-container {
230 | flex-direction: column;
231 | }
232 | #live-dashboard .metadata {
233 | flex-direction: column;
234 | align-items: flex-start;
235 | gap: 0.5rem;
236 | }
237 | }
--------------------------------------------------------------------------------
/update_predictions.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 | import re
4 | import subprocess
5 | import time
6 | from datetime import datetime, timezone, timedelta
7 | from pathlib import Path
8 |
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import pandas as pd
12 | import torch
13 | from binance.client import Client
14 |
15 | from model import KronosTokenizer, Kronos, KronosPredictor
16 |
17 | # --- Configuration ---
18 | Config = {
19 | "REPO_PATH": Path(__file__).parent.resolve(),
20 | "MODEL_PATH": "../Kronos_model",
21 | "SYMBOL": 'BTCUSDT',
22 | "INTERVAL": '1h',
23 | "HIST_POINTS": 360,
24 | "PRED_HORIZON": 24,
25 | "N_PREDICTIONS": 30,
26 | "VOL_WINDOW": 24,
27 | }
28 |
29 |
30 | def load_model():
31 | """Loads the Kronos model and tokenizer."""
32 | print("Loading Kronos model...")
33 | tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-2k", cache_dir=Config["MODEL_PATH"])
34 | model = Kronos.from_pretrained("NeoQuasar/Kronos-mini", cache_dir=Config["MODEL_PATH"])
35 | tokenizer.eval()
36 | model.eval()
37 | predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)
38 | print("Model loaded successfully.")
39 | return predictor
40 |
41 |
42 | def make_prediction(df, predictor):
43 | """Generates probabilistic forecasts using the Kronos model."""
44 | last_timestamp = df['timestamps'].max()
45 | start_new_range = last_timestamp + pd.Timedelta(hours=1)
46 | new_timestamps_index = pd.date_range(
47 | start=start_new_range,
48 | periods=Config["PRED_HORIZON"],
49 | freq='h'  # hourly; lowercase 'h' (uppercase 'H' is deprecated in newer pandas)
50 | )
51 | y_timestamp = pd.Series(new_timestamps_index, name='y_timestamp')
52 | x_timestamp = df['timestamps']
53 | x_df = df[['open', 'high', 'low', 'close', 'volume', 'amount']]
54 |
55 | with torch.no_grad():
56 | print("Making main prediction (T=1.0)...")
57 | begin_time = time.time()
58 | close_preds_main, volume_preds_main = predictor.predict(
59 | df=x_df, x_timestamp=x_timestamp, y_timestamp=y_timestamp,
60 | pred_len=Config["PRED_HORIZON"], T=1.0, top_p=0.95,
61 | sample_count=Config["N_PREDICTIONS"], verbose=True
62 | )
63 | print(f"Main prediction completed in {time.time() - begin_time:.2f} seconds.")
64 |
65 | # print("Making volatility prediction (T=0.9)...")
66 | # begin_time = time.time()
67 | # close_preds_volatility, _ = predictor.predict(
68 | # df=x_df, x_timestamp=x_timestamp, y_timestamp=y_timestamp,
69 | # pred_len=Config["PRED_HORIZON"], T=0.9, top_p=0.9,
70 | # sample_count=Config["N_PREDICTIONS"], verbose=True
71 | # )
72 | # print(f"Volatility prediction completed in {time.time() - begin_time:.2f} seconds.")
73 | close_preds_volatility = close_preds_main  # reuse the T=1.0 samples; the separate T=0.9 run above is disabled
74 |
75 | return close_preds_main, volume_preds_main, close_preds_volatility
76 |
77 |
78 | def fetch_binance_data():
79 | """Fetches K-line data from the Binance public API."""
80 | symbol, interval = Config["SYMBOL"], Config["INTERVAL"]
81 | limit = Config["HIST_POINTS"] + Config["VOL_WINDOW"]
82 |
83 | print(f"Fetching {limit} bars of {symbol} {interval} data from Binance...")
84 | client = Client()
85 | klines = client.get_klines(symbol=symbol, interval=interval, limit=limit)
86 |
87 | cols = ['open_time', 'open', 'high', 'low', 'close', 'volume', 'close_time',
88 | 'quote_asset_volume', 'number_of_trades', 'taker_buy_base_asset_volume',
89 | 'taker_buy_quote_asset_volume', 'ignore']
90 | df = pd.DataFrame(klines, columns=cols)
91 |
92 | df = df[['open_time', 'open', 'high', 'low', 'close', 'volume', 'quote_asset_volume']]
93 | df.rename(columns={'quote_asset_volume': 'amount', 'open_time': 'timestamps'}, inplace=True)
94 |
95 | df['timestamps'] = pd.to_datetime(df['timestamps'], unit='ms')
96 | for col in ['open', 'high', 'low', 'close', 'volume', 'amount']:
97 | df[col] = pd.to_numeric(df[col])
98 |
99 | print("Data fetched successfully.")
100 | return df
101 |
102 |
103 | def calculate_metrics(hist_df, close_preds_df, v_close_preds_df):
104 | """
105 | Calculates upside and volatility amplification probabilities for the 24h horizon.
106 | """
107 | last_close = hist_df['close'].iloc[-1]
108 |
109 | # 1. Upside Probability (for the 24-hour horizon)
110 | # This is the probability that the price at the end of the horizon is higher than now.
111 | final_hour_preds = close_preds_df.iloc[-1]
112 | upside_prob = (final_hour_preds > last_close).mean()
113 |
114 | # 2. Volatility Amplification Probability (over the 24-hour horizon)
115 | hist_log_returns = np.log(hist_df['close'] / hist_df['close'].shift(1))
116 | historical_vol = hist_log_returns.iloc[-Config["VOL_WINDOW"]:].std()
117 |
118 | amplification_count = 0
119 | for col in v_close_preds_df.columns:
120 | full_sequence = pd.concat([pd.Series([last_close]), v_close_preds_df[col]]).reset_index(drop=True)
121 | pred_log_returns = np.log(full_sequence / full_sequence.shift(1))
122 | predicted_vol = pred_log_returns.std()
123 | if predicted_vol > historical_vol:
124 | amplification_count += 1
125 |
126 | vol_amp_prob = amplification_count / len(v_close_preds_df.columns)
127 |
128 | print(f"Upside Probability (24h): {upside_prob:.2%}, Volatility Amplification Probability: {vol_amp_prob:.2%}")
129 | return upside_prob, vol_amp_prob
130 |
131 |
132 | def create_plot(hist_df, close_preds_df, volume_preds_df):
133 | """Generates and saves a comprehensive forecast chart."""
134 | print("Generating comprehensive forecast chart...")
135 | # plt.style.use('seaborn-v0_8-whitegrid')
136 | fig, (ax1, ax2) = plt.subplots(
137 | 2, 1, figsize=(15, 10), sharex=True,
138 | gridspec_kw={'height_ratios': [3, 1]}
139 | )
140 |
141 | hist_time = hist_df['timestamps']
142 | last_hist_time = hist_time.iloc[-1]
143 | pred_time = pd.to_datetime([last_hist_time + timedelta(hours=i + 1) for i in range(len(close_preds_df))])
144 |
145 | ax1.plot(hist_time, hist_df['close'], color='royalblue', label='Historical Price', linewidth=1.5)
146 | mean_preds = close_preds_df.mean(axis=1)
147 | ax1.plot(pred_time, mean_preds, color='darkorange', linestyle='-', label='Mean Forecast')
148 | ax1.fill_between(pred_time, close_preds_df.min(axis=1), close_preds_df.max(axis=1), color='darkorange', alpha=0.2, label='Forecast Range (Min-Max)')
149 | ax1.set_title(f'{Config["SYMBOL"]} Probabilistic Price & Volume Forecast (Next {Config["PRED_HORIZON"]} Hours)', fontsize=16, weight='bold')
150 | ax1.set_ylabel('Price (USDT)')
151 | ax1.legend()
152 | ax1.grid(True, which='both', linestyle='--', linewidth=0.5)
153 |
154 | ax2.bar(hist_time, hist_df['volume'], color='skyblue', label='Historical Volume', width=0.03)
155 | ax2.bar(pred_time, volume_preds_df.mean(axis=1), color='sandybrown', label='Mean Forecasted Volume', width=0.03)
156 | ax2.set_ylabel('Volume')
157 | ax2.set_xlabel('Time (UTC)')
158 | ax2.legend()
159 | ax2.grid(True, which='both', linestyle='--', linewidth=0.5)
160 |
161 | separator_time = hist_time.iloc[-1] + timedelta(minutes=30)
162 | for ax in [ax1, ax2]:
163 | ax.axvline(x=separator_time, color='red', linestyle='--', linewidth=1.5, label='_nolegend_')
164 | ax.tick_params(axis='x', rotation=30)
165 |
166 | fig.tight_layout()
167 | chart_path = Config["REPO_PATH"] / 'prediction_chart.png'
168 | fig.savefig(chart_path, dpi=120)
169 | plt.close(fig)
170 | print(f"Chart saved to: {chart_path}")
171 |
172 |
173 | def update_html(upside_prob, vol_amp_prob):
174 | """
175 | Updates the index.html file with the latest metrics and timestamp.
176 | This version uses a more robust lambda function for replacement to avoid formatting errors.
177 | """
178 | print("Updating index.html...")
179 | html_path = Config["REPO_PATH"] / 'index.html'
180 | now_utc_str = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
181 | upside_prob_str = f'{upside_prob:.1%}'
182 | vol_amp_prob_str = f'{vol_amp_prob:.1%}'
183 |
184 | with open(html_path, 'r', encoding='utf-8') as f:
185 | content = f.read()
186 |
187 | # Robustly replace content using lambda functions; the tags and ids in the patterns below must match the corresponding elements in index.html.
188 | content = re.sub(
189 | r'(<strong id="last-updated">).*?(</strong>)',
190 | lambda m: f'{m.group(1)}{now_utc_str}{m.group(2)}',
191 | content
192 | )
193 | content = re.sub(
194 | r'(<p class="metric-value" id="upside-prob">).*?(</p>)',
195 | lambda m: f'{m.group(1)}{upside_prob_str}{m.group(2)}',
196 | content
197 | )
198 | content = re.sub(
199 | r'(<p class="metric-value" id="vol-amp-prob">).*?(</p>)',
200 | lambda m: f'{m.group(1)}{vol_amp_prob_str}{m.group(2)}',
201 | content
202 | )
203 |
204 | with open(html_path, 'w', encoding='utf-8') as f:
205 | f.write(content)
206 | print("HTML file updated successfully.")
207 |
208 |
209 | def git_commit_and_push(commit_message):
210 | """Adds, commits, and pushes specified files to the Git repository."""
211 | print("Performing Git operations...")
212 | try:
213 | os.chdir(Config["REPO_PATH"])
214 | subprocess.run(['git', 'add', 'prediction_chart.png', 'index.html'], check=True, capture_output=True, text=True)
215 | commit_result = subprocess.run(['git', 'commit', '-m', commit_message], check=True, capture_output=True, text=True)
216 | print(commit_result.stdout)
217 | push_result = subprocess.run(['git', 'push'], check=True, capture_output=True, text=True)
218 | print(push_result.stdout)
219 | print("Git push successful.")
220 | except subprocess.CalledProcessError as e:
221 | output = e.stdout if e.stdout else e.stderr
222 | if "nothing to commit" in output or "Your branch is up to date" in output:
223 | print("No new changes to commit or push.")
224 | else:
225 | print(f"A Git error occurred:\n--- STDOUT ---\n{e.stdout}\n--- STDERR ---\n{e.stderr}")
226 |
227 |
228 | def main_task(model):
229 | """Executes one full update cycle."""
230 | print("\n" + "=" * 60 + f"\nStarting update task at {datetime.now(timezone.utc)}\n" + "=" * 60)
231 | df_full = fetch_binance_data()
232 | df_for_model = df_full.iloc[:-1]  # drop the last row: the final kline from Binance is still forming
233 |
234 | close_preds, volume_preds, v_close_preds = make_prediction(df_for_model, model)
235 |
236 | hist_df_for_plot = df_for_model.tail(Config["HIST_POINTS"])
237 | hist_df_for_metrics = df_for_model.tail(Config["VOL_WINDOW"])
238 |
239 | upside_prob, vol_amp_prob = calculate_metrics(hist_df_for_metrics, close_preds, v_close_preds)
240 | create_plot(hist_df_for_plot, close_preds, volume_preds)
241 | update_html(upside_prob, vol_amp_prob)
242 |
243 | commit_message = f"Auto-update forecast for {datetime.now(timezone.utc):%Y-%m-%d %H:%M} UTC"
244 | git_commit_and_push(commit_message)
245 |
246 | # --- Memory cleanup ---
247 | # Explicitly delete the large DataFrame objects to help the garbage collector.
248 | del df_full, df_for_model, close_preds, volume_preds, v_close_preds
249 | del hist_df_for_plot, hist_df_for_metrics
250 |
251 | # Force a garbage-collection pass.
252 | gc.collect()
253 | # --- End of memory cleanup ---
254 |
255 | print("-" * 60 + "\n--- Task completed successfully ---\n" + "-" * 60 + "\n")
256 |
257 |
258 | def run_scheduler(model):
259 | """A continuous scheduler that runs the main task hourly."""
260 | while True:
261 | now = datetime.now(timezone.utc)
262 | next_run_time = (now + timedelta(hours=1)).replace(minute=0, second=5, microsecond=0)
263 | sleep_seconds = (next_run_time - now).total_seconds()
264 |
265 | if sleep_seconds > 0:
266 | print(f"Current time: {now:%Y-%m-%d %H:%M:%S UTC}.")
267 | print(f"Next run at: {next_run_time:%Y-%m-%d %H:%M:%S UTC}. Waiting for {sleep_seconds:.0f} seconds...")
268 | time.sleep(sleep_seconds)
269 |
270 | try:
271 | main_task(model)
272 | except Exception as e:
273 | print(f"\n!!!!!! A critical error occurred in the main task !!!!!!!")
274 | print(f"Error: {e}")
275 | import traceback
276 | traceback.print_exc()
277 | print("Retrying in 5 minutes...")
278 | print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n")
279 | time.sleep(300)
280 |
281 |
282 | if __name__ == '__main__':
283 | model_path = Path(Config["MODEL_PATH"])
284 | model_path.mkdir(parents=True, exist_ok=True)
285 |
286 | loaded_model = load_model()
287 | main_task(loaded_model) # Run once on startup
288 | run_scheduler(loaded_model) # Start the schedule
--------------------------------------------------------------------------------
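
For a one-off forecast cycle without the hourly scheduler or the git push, the helpers above can be driven directly; a sketch, assuming the dependencies (python-binance, torch, pandas, matplotlib) are installed and the import runs from the repository root:

    from update_predictions import (load_model, fetch_binance_data,
                                    make_prediction, calculate_metrics, create_plot)

    predictor = load_model()
    df = fetch_binance_data().iloc[:-1]               # drop the still-forming candle, as main_task does
    closes, volumes, v_closes = make_prediction(df, predictor)
    calculate_metrics(df.tail(24), closes, v_closes)  # 24 == Config["VOL_WINDOW"]
    create_plot(df.tail(360), closes, volumes)        # 360 == Config["HIST_POINTS"]
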
/model/kronos.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import torch
4 | from huggingface_hub import PyTorchModelHubMixin
5 | import sys
6 |
7 | from tqdm import trange
8 |
9 | sys.path.append("../")
10 | from model.module import *
11 |
12 |
13 | class KronosTokenizer(nn.Module, PyTorchModelHubMixin):
14 | """
15 | KronosTokenizer module for tokenizing input data using a hybrid quantization approach.
16 |
17 | This tokenizer utilizes a combination of encoder and decoder Transformer blocks
18 | along with the Binary Spherical Quantization (BSQuantizer) to compress and decompress input data.
19 |
20 | Args:
21 | d_in (int): Input dimension.
22 | d_model (int): Model dimension.
23 | n_heads (int): Number of attention heads.
24 | ff_dim (int): Feed-forward dimension.
25 | n_enc_layers (int): Number of encoder layers.
26 | n_dec_layers (int): Number of decoder layers.
27 | ffn_dropout_p (float): Dropout probability for feed-forward networks.
28 | attn_dropout_p (float): Dropout probability for attention mechanisms.
29 | resid_dropout_p (float): Dropout probability for residual connections.
30 | s1_bits (int): Number of bits for the pre token in BSQuantizer.
31 | s2_bits (int): Number of bits for the post token in BSQuantizer.
32 | beta (float): Beta parameter for BSQuantizer.
33 | gamma0 (float): Gamma0 parameter for BSQuantizer.
34 | gamma (float): Gamma parameter for BSQuantizer.
35 | zeta (float): Zeta parameter for BSQuantizer.
36 | group_size (int): Group size parameter for BSQuantizer.
37 |
38 | """
39 |
40 | def __init__(self, d_in, d_model, n_heads, ff_dim, n_enc_layers, n_dec_layers, ffn_dropout_p, attn_dropout_p, resid_dropout_p, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size):
41 |
42 | super().__init__()
43 | self.d_in = d_in
44 | self.d_model = d_model
45 | self.n_heads = n_heads
46 | self.ff_dim = ff_dim
47 | self.enc_layers = n_enc_layers
48 | self.dec_layers = n_dec_layers
49 | self.ffn_dropout_p = ffn_dropout_p
50 | self.attn_dropout_p = attn_dropout_p
51 | self.resid_dropout_p = resid_dropout_p
52 |
53 | self.s1_bits = s1_bits
54 | self.s2_bits = s2_bits
55 | self.codebook_dim = s1_bits + s2_bits # Total dimension of the codebook after quantization
56 | self.embed = nn.Linear(self.d_in, self.d_model)
57 | self.head = nn.Linear(self.d_model, self.d_in)
58 |
59 | # Encoder Transformer Blocks
60 | self.encoder = nn.ModuleList([
61 | TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p)
62 | for _ in range(self.enc_layers - 1)
63 | ])
64 | # Decoder Transformer Blocks
65 | self.decoder = nn.ModuleList([
66 | TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p)
67 | for _ in range(self.dec_layers - 1)
68 | ])
69 | self.quant_embed = nn.Linear(in_features=self.d_model, out_features=self.codebook_dim) # Linear layer before quantization
70 | self.post_quant_embed_pre = nn.Linear(in_features=self.s1_bits, out_features=self.d_model) # Linear layer after quantization (pre part - s1 bits)
71 | self.post_quant_embed = nn.Linear(in_features=self.codebook_dim, out_features=self.d_model) # Linear layer after quantization (full codebook)
72 | self.tokenizer = BSQuantizer(self.s1_bits, self.s2_bits, beta, gamma0, gamma, zeta, group_size) # BSQuantizer module
73 |
74 | def forward(self, x):
75 | """
76 | Forward pass of the KronosTokenizer.
77 |
78 | Args:
79 | x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in).
80 |
81 | Returns:
82 | tuple: A tuple containing:
83 | - tuple: (z_pre, z) - Reconstructed outputs from decoder with s1_bits and full codebook respectively,
84 | both of shape (batch_size, seq_len, d_in).
85 | - torch.Tensor: bsq_loss - Loss from the BSQuantizer.
86 | - torch.Tensor: quantized - Quantized representation from BSQuantizer.
87 | - torch.Tensor: z_indices - Indices from the BSQuantizer.
88 | """
89 | z = self.embed(x)
90 |
91 | for layer in self.encoder:
92 | z = layer(z)
93 |
94 | z = self.quant_embed(z) # (B, T, codebook)
95 |
96 | bsq_loss, quantized, z_indices = self.tokenizer(z)
97 |
98 | quantized_pre = quantized[:, :, :self.s1_bits] # Extract the first part of quantized representation (s1_bits)
99 | z_pre = self.post_quant_embed_pre(quantized_pre)
100 |
101 | z = self.post_quant_embed(quantized)
102 |
103 | # Decoder layers (for pre part - s1 bits)
104 | for layer in self.decoder:
105 | z_pre = layer(z_pre)
106 | z_pre = self.head(z_pre)
107 |
108 | # Decoder layers (for full codebook)
109 | for layer in self.decoder:
110 | z = layer(z)
111 | z = self.head(z)
112 |
113 | return (z_pre, z), bsq_loss, quantized, z_indices
114 |
115 | def indices_to_bits(self, x, half=False):
116 | """
117 | Converts indices to bit representations and scales them.
118 |
119 | Args:
120 | x (torch.Tensor): Indices tensor.
121 | half (bool, optional): Whether to process only half of the codebook dimension. Defaults to False.
122 |
123 | Returns:
124 | torch.Tensor: Bit representation tensor.
125 | """
126 | if half:
127 | x1 = x[0] # Assuming x is a tuple of indices if half is True
128 | x2 = x[1]
129 | mask = 2 ** torch.arange(self.codebook_dim//2, device=x1.device, dtype=torch.long) # Create a mask for bit extraction
130 | x1 = (x1.unsqueeze(-1) & mask) != 0 # Extract bits for the first half
131 | x2 = (x2.unsqueeze(-1) & mask) != 0 # Extract bits for the second half
132 | x = torch.cat([x1, x2], dim=-1) # Concatenate the bit representations
133 | else:
134 | mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long) # Create a mask for bit extraction
135 | x = (x.unsqueeze(-1) & mask) != 0 # Extract bits
136 |
137 | x = x.float() * 2 - 1 # Convert boolean to bipolar (-1, 1)
138 | q_scale = 1. / (self.codebook_dim ** 0.5) # Scaling factor
139 | x = x * q_scale
140 | return x
141 |
142 | def encode(self, x, half=False):
143 | """
144 | Encodes the input data into quantized indices.
145 |
146 | Args:
147 | x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in).
148 | half (bool, optional): Whether to use half quantization in BSQuantizer. Defaults to False.
149 |
150 | Returns:
151 | torch.Tensor: Quantized indices from BSQuantizer.
152 | """
153 | z = self.embed(x)
154 | for layer in self.encoder:
155 | z = layer(z)
156 | z = self.quant_embed(z)
157 |
158 | bsq_loss, quantized, z_indices = self.tokenizer(z, half)
159 | return z_indices
160 |
161 | def decode(self, x, half=False):
162 | """
163 | Decodes quantized indices back to the input data space.
164 |
165 | Args:
166 | x (torch.Tensor): Quantized indices tensor.
167 | half (bool, optional): Whether the indices were generated with half quantization. Defaults to False.
168 |
169 | Returns:
170 | torch.Tensor: Reconstructed output tensor of shape (batch_size, seq_len, d_in).
171 | """
172 | quantized = self.indices_to_bits(x, half)
173 | z = self.post_quant_embed(quantized)
174 | for layer in self.decoder:
175 | z = layer(z)
176 | z = self.head(z)
177 | return z
178 |
179 |
180 | class Kronos(nn.Module, PyTorchModelHubMixin):
181 | """
182 | Kronos Model.
183 |
184 | Args:
185 | s1_bits (int): Number of bits for pre tokens.
186 | s2_bits (int): Number of bits for post tokens.
187 | n_layers (int): Number of Transformer blocks.
188 | d_model (int): Dimension of the model's embeddings and hidden states.
189 | n_heads (int): Number of attention heads in the MultiheadAttention layers.
190 | ff_dim (int): Dimension of the feedforward network in the Transformer blocks.
191 | ffn_dropout_p (float): Dropout probability for the feedforward network.
192 | attn_dropout_p (float): Dropout probability for the attention layers.
193 | resid_dropout_p (float): Dropout probability for residual connections.
194 | token_dropout_p (float): Dropout probability for token embeddings.
195 | learn_te (bool): Whether to use learnable temporal embeddings.
196 | """
197 |
198 | def __init__(self, s1_bits, s2_bits, n_layers, d_model, n_heads, ff_dim, ffn_dropout_p, attn_dropout_p, resid_dropout_p, token_dropout_p, learn_te):
199 | super().__init__()
200 | self.s1_bits = s1_bits
201 | self.s2_bits = s2_bits
202 | self.n_layers = n_layers
203 | self.d_model = d_model
204 | self.n_heads = n_heads
205 | self.learn_te = learn_te
206 | self.ff_dim = ff_dim
207 | self.ffn_dropout_p = ffn_dropout_p
208 | self.attn_dropout_p = attn_dropout_p
209 | self.resid_dropout_p = resid_dropout_p
210 | self.token_dropout_p = token_dropout_p
211 |
212 | self.s1_vocab_size = 2 ** self.s1_bits
213 | self.token_drop = nn.Dropout(self.token_dropout_p)
214 | self.embedding = HierarchicalEmbedding(self.s1_bits, self.s2_bits, self.d_model)
215 | self.time_emb = TemporalEmbedding(self.d_model, self.learn_te)
216 | self.transformer = nn.ModuleList([
217 | TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p)
218 | for _ in range(self.n_layers)
219 | ])
220 | self.norm = RMSNorm(self.d_model)
221 | self.dep_layer = DependencyAwareLayer(self.d_model)
222 | self.head = DualHead(self.s1_bits, self.s2_bits, self.d_model)
223 | self.apply(self._init_weights)
224 |
225 | def _init_weights(self, module):
226 |
227 | if isinstance(module, nn.Linear):
228 | nn.init.xavier_normal_(module.weight)
229 | if module.bias is not None:
230 | nn.init.zeros_(module.bias)
231 | elif isinstance(module, nn.Embedding):
232 | nn.init.normal_(module.weight, mean=0, std=self.embedding.d_model ** -0.5)
233 | elif isinstance(module, nn.LayerNorm):
234 | nn.init.ones_(module.weight)
235 | nn.init.zeros_(module.bias)
236 | elif isinstance(module, RMSNorm):
237 | nn.init.ones_(module.weight)
238 |
239 | def forward(self, s1_ids, s2_ids, stamp=None, padding_mask=None, use_teacher_forcing=False, s1_targets=None):
240 | """
241 | Args:
242 | s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
243 | s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len]
244 | stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None.
245 | padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
246 | use_teacher_forcing (bool, optional): Whether to use teacher forcing for s1 decoding. Defaults to False.
247 | s1_targets (torch.Tensor, optional): Target s1 token IDs for teacher forcing. Shape: [batch_size, seq_len]. Defaults to None.
248 |
249 | Returns:
250 | Tuple[torch.Tensor, torch.Tensor]:
251 | - s1_logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size]
252 | - s2_logits: Logits for s2 token predictions, conditioned on s1. Shape: [batch_size, seq_len, s2_vocab_size]
253 | """
254 | x = self.embedding([s1_ids, s2_ids])
255 | if stamp is not None:
256 | time_embedding = self.time_emb(stamp)
257 | x = x + time_embedding
258 | x = self.token_drop(x)
259 |
260 | for layer in self.transformer:
261 | x = layer(x, key_padding_mask=padding_mask)
262 |
263 | x = self.norm(x)
264 |
265 | s1_logits = self.head(x)
266 |
267 | if use_teacher_forcing:
268 | sibling_embed = self.embedding.emb_s1(s1_targets)
269 | else:
270 | s1_probs = F.softmax(s1_logits.detach(), dim=-1)
271 | sample_s1_ids = torch.multinomial(s1_probs.view(-1, self.s1_vocab_size), 1).view(s1_ids.shape)
272 | sibling_embed = self.embedding.emb_s1(sample_s1_ids)
273 |
274 | x2 = self.dep_layer(x, sibling_embed, key_padding_mask=padding_mask) # Dependency Aware Layer: Condition on s1 embeddings
275 | s2_logits = self.head.cond_forward(x2)
276 | return s1_logits, s2_logits
277 |
278 | def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None):
279 | """
280 | Decodes only the s1 tokens.
281 |
282 | This method performs a forward pass to predict only s1 tokens. It returns the s1 logits
283 | and the context representation from the Transformer, which can be used for subsequent s2 decoding.
284 |
285 | Args:
286 | s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
287 | s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len]
288 | stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None.
289 | padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
290 |
291 | Returns:
292 | Tuple[torch.Tensor, torch.Tensor]:
293 | - s1_logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size]
294 | - context: Context representation from the Transformer. Shape: [batch_size, seq_len, d_model]
295 | """
296 | x = self.embedding([s1_ids, s2_ids])
297 | if stamp is not None:
298 | time_embedding = self.time_emb(stamp)
299 | x = x + time_embedding
300 | x = self.token_drop(x)
301 |
302 | for layer in self.transformer:
303 | x = layer(x, key_padding_mask=padding_mask)
304 |
305 | x = self.norm(x)
306 |
307 | s1_logits = self.head(x)
308 | return s1_logits, x
309 |
310 | def decode_s2(self, context, s1_ids, padding_mask=None):
311 | """
312 | Decodes the s2 tokens, conditioned on the context and s1 tokens.
313 |
314 | This method decodes s2 tokens based on a pre-computed context representation (typically from `decode_s1`)
315 | and the s1 token IDs. It uses the dependency-aware layer and the conditional s2 head to predict s2 tokens.
316 |
317 | Args:
318 | context (torch.Tensor): Context representation from the transformer (output of decode_s1).
319 | Shape: [batch_size, seq_len, d_model]
320 | s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
321 | padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
322 |
323 | Returns:
324 | torch.Tensor: s2 logits. Shape: [batch_size, seq_len, s2_vocab_size]
325 | """
326 | sibling_embed = self.embedding.emb_s1(s1_ids)
327 | x2 = self.dep_layer(context, sibling_embed, key_padding_mask=padding_mask)
328 | return self.head.cond_forward(x2)
329 |
330 |
331 | def top_k_top_p_filtering(
332 | logits,
333 | top_k: int = 0,
334 | top_p: float = 1.0,
335 | filter_value: float = -float("Inf"),
336 | min_tokens_to_keep: int = 1,
337 | ):
338 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
339 | Args:
340 | logits: logits distribution shape (batch size, vocabulary size)
341 | if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
342 | if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
343 | Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
344 | Make sure we keep at least min_tokens_to_keep per batch example in the output
345 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
346 | """
347 | if top_k > 0:
348 | top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
349 | # Remove all tokens with a probability less than the last token of the top-k
350 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
351 | logits[indices_to_remove] = filter_value
352 | return logits  # note: the top-k branch returns early, so top-p only applies when top_k == 0
353 |
354 | if top_p < 1.0:
355 | sorted_logits, sorted_indices = torch.sort(logits, descending=True)
356 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
357 |
358 | # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
359 | sorted_indices_to_remove = cumulative_probs > top_p
360 | if min_tokens_to_keep > 1:
361 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
362 | sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
363 | # Shift the indices to the right to keep also the first token above the threshold
364 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
365 | sorted_indices_to_remove[..., 0] = 0
366 |
367 | # scatter sorted tensors to original indexing
368 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
369 | logits[indices_to_remove] = filter_value
370 | return logits
371 |
372 |
373 | def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_logits=True):
374 | logits = logits / temperature
375 | if (top_k is not None and top_k > 0) or (top_p is not None and top_p < 1.0):
376 | # Guard against None before comparing; None means the corresponding filter is disabled.
377 | logits = top_k_top_p_filtering(logits, top_k=top_k or 0, top_p=1.0 if top_p is None else top_p)
378 |
379 | probs = F.softmax(logits, dim=-1)
380 |
381 | if not sample_logits:
382 | _, x = torch.topk(probs, k=1, dim=-1)  # greedy: take the single most likely token
383 | else:
384 | x = torch.multinomial(probs, num_samples=1)
385 |
386 | return x
387 |
388 |
389 | def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False):
390 | with torch.no_grad():
391 | batch_size = x.size(0)
392 | initial_seq_len = x.size(1)
393 | x = torch.clip(x, -clip, clip)
394 |
395 | device = x.device
396 | x = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x.size(1), x.size(2)).to(device)
397 | x_stamp = x_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x_stamp.size(1), x_stamp.size(2)).to(device)
398 | y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device)
399 |
400 | x_token = tokenizer.encode(x, half=True)
401 |
402 | def get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, pred_step):
403 |
404 | if current_seq_len <= max_context - pred_step:
405 | return torch.cat([x_stamp, y_stamp[:, :pred_step, :]], dim=1)
406 | else:
407 | start_idx = max_context - pred_step
408 | return torch.cat([x_stamp[:, -start_idx:, :], y_stamp[:, :pred_step, :]], dim=1)
409 |
410 | if verbose:
411 | ran = trange
412 | else:
413 | ran = range
414 | for i in ran(pred_len):
415 | current_seq_len = initial_seq_len + i
416 |
417 | if current_seq_len <= max_context:
418 | input_tokens = x_token
419 | else:
420 | input_tokens = [t[:, -max_context:].contiguous() for t in x_token]
421 |
422 | current_stamp = get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, i)
423 |
424 | s1_logits, context = model.decode_s1(input_tokens[0], input_tokens[1], current_stamp)  # stage 1: coarse (s1) logits plus shared context
425 | s1_logits = s1_logits[:, -1, :]
426 | sample_pre = sample_from_logits(s1_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)
427 |
428 | s2_logits = model.decode_s2(context, sample_pre)  # stage 2: fine (s2) logits conditioned on the sampled s1 token
429 | s2_logits = s2_logits[:, -1, :]
430 | sample_post = sample_from_logits(s2_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)
431 |
432 | x_token[0] = torch.cat([x_token[0], sample_pre], dim=1)
433 | x_token[1] = torch.cat([x_token[1], sample_post], dim=1)
434 |
435 | input_tokens = [t[:, -max_context:].contiguous() for t in x_token]
436 | z = tokenizer.decode(input_tokens, half=True)
437 | z = z.reshape(batch_size, sample_count, z.size(1), z.size(2))
438 | preds = z.cpu().numpy()
439 | # preds = np.mean(preds, axis=1)
440 |
441 | return preds
442 |
443 |
444 | def calc_time_stamps(x_timestamp):
445 | time_df = pd.DataFrame()
446 | time_df['minute'] = x_timestamp.dt.minute
447 | time_df['hour'] = x_timestamp.dt.hour
448 | time_df['weekday'] = x_timestamp.dt.weekday
449 | time_df['day'] = x_timestamp.dt.day
450 | time_df['month'] = x_timestamp.dt.month
451 | return time_df
452 |
453 |
454 | class KronosPredictor:
455 |
456 | def __init__(self, model, tokenizer, device="cuda:0", max_context=512, clip=5):
457 | self.tokenizer = tokenizer
458 | self.model = model
459 | self.max_context = max_context
460 | self.clip = clip
461 | self.price_cols = ['open', 'high', 'low', 'close']
462 | self.vol_col = 'volume'
463 | self.amt_vol = 'amount'
464 | self.time_cols = ['minute', 'hour', 'weekday', 'day', 'month']
465 | self.device = device
466 |
467 | self.tokenizer = self.tokenizer.to(self.device)
468 | self.model = self.model.to(self.device)
469 |
470 | def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose):
471 |
472 | x_tensor = torch.from_numpy(np.array(x).astype(np.float32)).to(self.device)
473 | x_stamp_tensor = torch.from_numpy(np.array(x_stamp).astype(np.float32)).to(self.device)
474 | y_stamp_tensor = torch.from_numpy(np.array(y_stamp).astype(np.float32)).to(self.device)
475 |
476 | preds = auto_regressive_inference(self.tokenizer, self.model, x_tensor, x_stamp_tensor, y_stamp_tensor, self.max_context, pred_len,
477 | self.clip, T, top_k, top_p, sample_count, verbose)
478 | preds = preds[:, :, -pred_len:, :]
479 | return preds
480 |
481 | def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True):
482 |
483 | if not isinstance(df, pd.DataFrame):
484 | raise ValueError("Input must be a pandas DataFrame.")
485 |
486 | if not all(col in df.columns for col in self.price_cols):
487 | raise ValueError(f"Price columns {self.price_cols} not found in DataFrame.")
488 |
489 | df = df.copy()
490 | if self.vol_col not in df.columns:
491 | df[self.vol_col] = 0.0 # Fill missing volume with zeros
492 | df[self.amt_vol] = 0.0 # Fill missing amount with zeros
493 | if self.amt_vol not in df.columns and self.vol_col in df.columns:
494 | df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1)
495 |
496 | if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any():
497 | raise ValueError("Input DataFrame contains NaN values in price or volume columns.")
498 |
499 | x_time_df = calc_time_stamps(x_timestamp)
500 | y_time_df = calc_time_stamps(y_timestamp)
501 |
502 | x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32)
503 | x_stamp = x_time_df.values.astype(np.float32)
504 | y_stamp = y_time_df.values.astype(np.float32)
505 |
506 | x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
507 |
508 | x = (x - x_mean) / (x_std + 1e-5)
509 | x = np.clip(x, -self.clip, self.clip)
510 |
511 | x = x[np.newaxis, :]
512 | x_stamp = x_stamp[np.newaxis, :]
513 | y_stamp = y_stamp[np.newaxis, :]
514 |
515 | preds = self.generate(x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose)
516 |
517 | preds = preds.squeeze(0)
518 | preds = preds * (x_std[np.newaxis, :] + 1e-5) + x_mean[np.newaxis, :]
519 |
520 | close_preds = preds[:, :, 3].swapaxes(0, 1)
521 | volume_preds = preds[:, :, 4].swapaxes(0, 1)
522 |
523 | close_df = pd.DataFrame(close_preds, columns=[f"pred-{i+1}" for i in range(sample_count)], index=y_timestamp)
524 | volume_df = pd.DataFrame(volume_preds, columns=[f"pred-{i + 1}" for i in range(sample_count)], index=y_timestamp)
525 |
526 | return close_df, volume_df
527 |
--------------------------------------------------------------------------------
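
Since KronosPredictor.predict is the entry point update_predictions.py relies on, here is a self-contained sketch of calling it directly; the synthetic OHLCV frame merely exercises the expected shapes, and the checkpoint names are the ones the demo loads:

    import numpy as np
    import pandas as pd
    from model import KronosTokenizer, Kronos, KronosPredictor

    # Synthetic hourly OHLCV history (360 bars, matching the demo's context length).
    idx = pd.date_range("2025-01-01", periods=360, freq="h")
    close = 100 + np.cumsum(np.random.randn(360))
    df = pd.DataFrame({"open": close, "high": close + 1.0, "low": close - 1.0,
                       "close": close, "volume": 10.0, "amount": 1000.0})

    tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-2k")
    model = Kronos.from_pretrained("NeoQuasar/Kronos-mini")
    tokenizer.eval(); model.eval()
    predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)

    future = pd.Series(pd.date_range(idx[-1] + pd.Timedelta(hours=1), periods=24, freq="h"))
    close_df, volume_df = predictor.predict(df=df, x_timestamp=pd.Series(idx),
                                            y_timestamp=future, pred_len=24,
                                            T=1.0, top_p=0.95, sample_count=3)
    print(close_df.shape)  # (24, 3): one column per sampled path
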
/model/module.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from einops import rearrange, reduce
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Function
7 | import torch.nn.functional as F
8 |
9 |
10 | class DifferentiableEntropyFunction(Function):
11 | @staticmethod
12 | def forward(ctx, zq, basis, K, eps):
13 | zb = (zq + 1) / 2
14 | zi = ((zb * basis).sum(-1)).to(torch.int64)
15 | cnt = torch.scatter_reduce(torch.zeros(2 ** K, device=zq.device, dtype=zq.dtype),
16 | 0,
17 | zi.flatten(),
18 | torch.ones_like(zi.flatten()).to(zq.dtype),
19 | 'sum')
20 | prob = (cnt + eps) / (cnt + eps).sum()
21 | H = -(prob * torch.log(prob)).sum()
22 | ctx.save_for_backward(zq, zi, prob)
23 | ctx.K = K
24 | return H
25 |
26 | @staticmethod
27 | def backward(ctx, grad_output):
28 | zq, zi, prob = ctx.saved_tensors
29 | grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K
30 | reord_grad = grad_array[zi.flatten()].reshape(zi.shape)
31 | grad_input = reord_grad.unsqueeze(-1) * zq
32 | return grad_input, None, None, None, None
33 |
34 |
35 | def codebook_entropy(zq, basis, K, eps=1e-4):
36 | return DifferentiableEntropyFunction.apply(zq, basis, K, eps)
37 |
38 |
39 | class BinarySphericalQuantizer(nn.Module):
40 | def __init__(self, embed_dim, beta, gamma0, gamma, zeta,
41 | input_format='bchw',
42 | soft_entropy=True, group_size=9,
43 | persample_entropy_compute='analytical',
44 | cb_entropy_compute='group',
45 | l2_norm=True,
46 | inv_temperature=1):
47 | """
48 | Paper link: https://arxiv.org/pdf/2406.07548.pdf
49 | Here we use the official implementation of the BinarySphericalQuantizer.
50 | """
51 | super().__init__()
52 | self.embed_dim = embed_dim
53 | self.beta = beta # loss weight for commit loss
54 | self.gamma0 = gamma0 # loss weight for entropy penalty
55 | self.gamma = gamma # loss weight for entropy penalty
56 | self.zeta = zeta # loss weight for entire entropy penalty
57 | self.input_format = input_format
58 | assert self.embed_dim % group_size == 0, "embed_dim must be divisible by group_size"
59 | self.num_groups = self.embed_dim // group_size
60 | self.group_size = group_size
61 | assert persample_entropy_compute in ['group', 'analytical'], "persample_entropy_compute must be either 'group' or 'analytical'"
62 | assert cb_entropy_compute in ['group', 'nce'], "cb_entropy_compute must be either 'group' or 'nce'"
63 | self.persample_entropy_compute = persample_entropy_compute
64 | self.cb_entropy_compute = cb_entropy_compute
65 | self.l2_norm = l2_norm
66 | self.inv_temperature = inv_temperature
67 |
68 | self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1))
69 | self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1))
70 |
71 | self.num_dimensions = 2 ** embed_dim
72 | self.bits_per_index = embed_dim
73 |
74 | # we only need to keep the codebook portion up to the group size
75 | # because we approximate the H loss with this subcode
76 | group_codes = torch.arange(2 ** self.group_size)
77 | group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:]
78 | self.register_buffer('group_codebook', group_codebook, persistent=False)
79 |
80 | self.soft_entropy = soft_entropy # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf
81 |
82 | def quantize(self, z):
83 | assert z.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}"
84 |
85 | zhat = torch.where(z > 0,
86 | torch.tensor(1, dtype=z.dtype, device=z.device),
87 | torch.tensor(-1, dtype=z.dtype, device=z.device))
88 | return z + (zhat - z).detach()
89 |
90 | def forward(self, z):
91 | # if self.input_format == 'bchw':
92 | # z = rearrange(z, 'b c h w -> b h w c')
93 | zq = self.quantize(z)
94 |
95 | indices = self.codes_to_indexes(zq.detach())
96 | group_indices = self.codes_to_group_indexes(zq.detach())
97 | if not self.training:
98 | used_codes = torch.unique(indices, return_counts=False)
99 | else:
100 | used_codes = None
101 |
102 | q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
103 |
104 | if self.soft_entropy:
105 | persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z)
106 | entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
107 | else:
108 | zb_by_sample = ((zq + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32)
109 | persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample)
110 | cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim)
111 | entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
112 | avg_prob = None  # only computed in the soft-entropy branch; keeps the metrics dict below well-defined
113 | zq = zq * q_scale
114 |
115 | # commit loss
116 | commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1))
117 |
118 | # if self.input_format == 'bchw':
119 | # zq = rearrange(zq, 'b h w c -> b c h w')
120 |
121 | return (
122 | zq,
123 | commit_loss + self.zeta * entropy_penalty / self.inv_temperature,
124 | {"H": cb_entropy, "used_codes": used_codes, "indices": indices, "group_indices": group_indices,
125 | "avg_prob": avg_prob}
126 | )
127 |
128 | def soft_entropy_loss(self, z):
129 | # if we divide the code in subgroups of size group_size, the codebook will be of size 2 ** group_size
130 | # the sub-code is the last group_size bits of the full code
131 | group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1)
132 | divided_z = rearrange(z, '... (g c) -> ... g c', c=self.group_size)
133 |
134 | # we calculate the distance between the divided_z and the codebook for each subgroup
135 | distance = - 2 * torch.einsum('... g c, d c ->... g d', divided_z, group_code_book)
136 | prob = (-distance * self.inv_temperature).softmax(dim=-1)
137 | if self.persample_entropy_compute == 'analytical':
138 | if self.l2_norm:
139 | p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature)
140 | else:
141 | p = torch.sigmoid(-4 * z * self.inv_temperature)
142 | prob = torch.stack([p, 1 - p], dim=-1)
143 | per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
144 | else:
145 | per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
146 |
147 | # macro average of the probability of each subgroup
148 | avg_prob = reduce(prob, '... g d ->g d', 'mean')
149 | codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)
150 |
151 | # the approximation of the entropy is the sum of the entropy of each subgroup
152 | return per_sample_entropy, codebook_entropy.sum(), avg_prob
153 |
154 | def get_hard_per_sample_entropy(self, zb_by_sample):
155 | probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1]
156 | persample_entropy = - probs_per_dim * torch.log(probs_per_dim + 1e-8) - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8)
157 | persample_entropy = persample_entropy.sum(-1)
158 | return persample_entropy.mean()
159 |
160 | def codes_to_indexes(self, zhat):
161 | """Converts a `code` to an index in the codebook.
162 | Args:
163 | zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
164 | """
165 | assert zhat.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}"
166 | return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64)
167 |
168 | def codes_to_group_indexes(self, zhat):
169 | """Converts a `code` to a list of indexes (in groups) in the codebook.
170 | Args:
171 | zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
172 | """
173 | zhat_in_group = rearrange(zhat, 'b ... (g c) -> b ... g c', c=self.group_size)
174 | return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64)
175 |
176 | def indexes_to_codes(self, indices):
177 | """Inverse of `indexes_to_codes`."""
178 | indices = indices.unsqueeze(-1)
179 | codes_non_centered = torch.remainder(
180 | torch.floor_divide(indices, self.basis), 2
181 | )
182 | return codes_non_centered * 2 - 1
183 |
184 | def group_indexes_to_codes(self, group_indices):
185 | """Inverse of `group_indexes_to_codes`."""
186 | group_indices = group_indices.unsqueeze(-1)
187 | codes_non_centered = torch.remainder(
188 | torch.floor_divide(group_indices, self.group_basis), 2
189 | )
190 | codes_non_centered = rearrange(codes_non_centered, 'b ... g c -> b ... (g c)')
191 | return codes_non_centered * 2 - 1
192 |
193 | def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
194 | if normalize:
195 | probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True)
196 | else:
197 | probs = count
198 | H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
199 | return H
200 |
201 | def get_group_codebook_entry(self, group_indices):
202 | z_q = self.group_indexes_to_codes(group_indices)
203 | q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
204 | z_q = z_q * q_scale
205 | if self.input_format == 'bchw':
206 | h = w = int(z_q.shape[1] ** 0.5)  # assumes a square spatial layout
207 | assert h * w == z_q.shape[1], 'Invalid sequence length'
208 | z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
209 | return z_q
210 |
211 | def get_codebook_entry(self, indices):
212 | z_q = self.indexes_to_codes(indices)
213 | q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
214 | z_q = z_q * q_scale
215 | if self.input_format == 'bchw':
216 | h = w = int(z_q.shape[1] ** 0.5)  # assumes a square spatial layout
217 | assert h * w == z_q.shape[1], 'Invalid sequence length'
218 | z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
219 | return z_q
220 |
221 |
222 | class BSQuantizer(nn.Module):
223 |
224 | def __init__(self, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size):
225 | super().__init__()
226 | self.codebook_dim = s1_bits + s2_bits
227 | self.s1_bits = s1_bits
228 | self.s2_bits = s2_bits
229 | self.bsq = BinarySphericalQuantizer(self.codebook_dim, beta, gamma0, gamma, zeta, group_size=group_size)
230 |
231 | def bits_to_indices(self, bits):
232 | bits = (bits >= 0).to(torch.long)
233 | indices = 2 ** torch.arange(
234 | 0,
235 | bits.shape[-1],
236 | 1,
237 | dtype=torch.long,
238 | device=bits.device,
239 | )
240 | return (bits * indices).sum(-1)
241 |
242 | def forward(self, z, half=False):
243 | z = F.normalize(z, dim=-1)
244 | quantized, bsq_loss, metrics = self.bsq(z)
245 | if half:
246 | q_pre = quantized[:, :, :self.s1_bits]
247 | q_post = quantized[:, :, self.s1_bits:]
248 | z_indices = [self.bits_to_indices(q_pre), self.bits_to_indices(q_post)]
249 | else:
250 | z_indices = self.bits_to_indices(quantized)
251 | return bsq_loss, quantized, z_indices
252 |
253 |
254 | class RMSNorm(torch.nn.Module):
255 | def __init__(self, dim: int, eps: float = 1e-5):
256 | super().__init__()
257 | self.eps = eps
258 | self.weight = nn.Parameter(torch.ones(dim))
259 |
260 | def _norm(self, x):
261 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
262 |
263 | def forward(self, x):
264 | output = self._norm(x.float()).type_as(x)
265 | return output * self.weight
266 |
267 |
268 | class FeedForward(nn.Module):
269 | def __init__(self, d_model, ff_dim, ffn_dropout_p=0.0):
270 | super().__init__()
271 |
272 | self.w1 = nn.Linear(d_model, ff_dim, bias=False)
273 | self.w3 = nn.Linear(d_model, ff_dim, bias=False)
274 | self.w2 = nn.Linear(ff_dim, d_model, bias=False)
275 | self.ffn_dropout = nn.Dropout(ffn_dropout_p)
276 |
277 | def forward(self, x):
278 | return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
279 |
280 |
281 | class RotaryPositionalEmbedding(nn.Module):
282 | def __init__(self, dim):
283 | super().__init__()
284 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
285 | self.register_buffer("inv_freq", inv_freq)
286 | self.seq_len_cached = None
287 | self.cos_cached = None
288 | self.sin_cached = None
289 |
290 | def _update_cos_sin_cache(self, x, seq_len):
291 | if seq_len != self.seq_len_cached:
292 | self.seq_len_cached = seq_len
293 | t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
294 | freqs = torch.einsum('i,j->ij', t, self.inv_freq)
295 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
296 | self.cos_cached = emb.cos()[None, None, :, :]
297 | self.sin_cached = emb.sin()[None, None, :, :]
298 | return self.cos_cached, self.sin_cached
299 |
300 | def forward(self, q, k):
301 |         cos, sin = self._update_cos_sin_cache(q, q.shape[-2])  # assumes q and k share a sequence length
302 | return (
303 | (q * cos) + (self._rotate_half(q) * sin),
304 | (k * cos) + (self._rotate_half(k) * sin),
305 | )
306 |
307 | def _rotate_half(self, x):
308 | x1, x2 = x.chunk(2, dim=-1)
309 | return torch.cat((-x2, x1), dim=-1)
310 |
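# Usage sketch (illustrative): rotary embeddings rotate query/key feature pairs
# by position-dependent angles; the cos/sin tables are cached per sequence
# length. Shapes are assumptions for the example.
#
#   rope = RotaryPositionalEmbedding(dim=64)      # per-head dimension
#   q = torch.randn(2, 4, 16, 64)                 # [batch, heads, seq, head_dim]
#   k = torch.randn(2, 4, 16, 64)
#   q_rot, k_rot = rope(q, k)                     # same shapes, positions encoded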
311 |
312 | def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, training=True) -> torch.Tensor:
313 | L, S = query.size(-2), key.size(-2)
314 | scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
315 |     attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
316 |
317 | if is_causal:
318 | assert attn_mask is None
319 |         temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
320 |         attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
321 |         attn_bias = attn_bias.to(query.dtype)
322 |
323 | attn_weight = query @ key.transpose(-2, -1) * scale_factor
324 | attn_weight += attn_bias
325 |
326 | if attn_mask is not None:
327 | attn_mask_bias = torch.zeros_like(attn_weight)
328 | if attn_mask.dtype == torch.bool:
329 |             attn_mask_bias.masked_fill_(attn_mask, float("-inf"))  # True marks positions to exclude
330 | else:
331 | attn_mask_bias += attn_mask
332 | attn_weight += attn_mask_bias
333 |
334 | attn_weight = torch.softmax(attn_weight, dim=-1)
335 | attn_weight = torch.dropout(attn_weight, dropout_p, train=training)
336 | return attn_weight @ value
337 |
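# Usage sketch (illustrative): this mirrors the reference semantics of PyTorch's
# F.scaled_dot_product_attention, except that a boolean attn_mask here marks
# positions to *exclude* (key-padding convention) rather than positions to keep.
#
#   q = k = v = torch.randn(2, 4, 16, 64)
#   out = scaled_dot_product_attention(q, k, v, is_causal=True, training=False)
#   # out: [2, 4, 16, 64]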
338 |
339 | class MultiHeadAttentionWithRoPE(nn.Module):
340 | def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout_p=0.0):
341 | super().__init__()
342 | self.d_model = d_model
343 | self.n_heads = n_heads
344 | self.head_dim = d_model // n_heads
345 |
346 | self.q_proj = nn.Linear(d_model, d_model)
347 | self.k_proj = nn.Linear(d_model, d_model)
348 | self.v_proj = nn.Linear(d_model, d_model)
349 | self.out_proj = nn.Linear(d_model, d_model)
350 | self.rotary = RotaryPositionalEmbedding(self.head_dim)
351 | self.attn_dropout_p = attn_dropout_p
352 | self.resid_dropout = nn.Dropout(resid_dropout_p)
353 |
354 | def forward(self, x, key_padding_mask=None):
355 | batch_size, seq_len, _ = x.shape
356 |
357 | q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
358 | k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
359 | v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
360 |
361 | q, k = self.rotary(q, k)
362 |
363 | if key_padding_mask is not None:
364 | attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) # [batch, 1, 1, seq_len]
365 | attn_mask = attn_mask.expand(-1, self.n_heads, seq_len, -1) # [batch, n_heads, q_len, k_len]
366 | else:
367 | attn_mask = None
368 |
369 | attn_output = scaled_dot_product_attention(
370 | q, k, v,
371 | attn_mask=attn_mask,
372 | dropout_p=self.attn_dropout_p,
373 | is_causal=True,
374 | training=self.training
375 | )
376 |
377 | attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
378 | return self.resid_dropout(self.out_proj(attn_output))
379 |
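# Usage sketch (illustrative): causal self-attention with RoPE applied to q/k.
# key_padding_mask is boolean with True at padded positions (consistent with
# how the mask is converted to -inf biases above); shapes are assumptions.
#
#   attn = MultiHeadAttentionWithRoPE(d_model=256, n_heads=4)
#   x = torch.randn(2, 16, 256)
#   pad = torch.zeros(2, 16, dtype=torch.bool)    # no padding
#   y = attn(x, key_padding_mask=pad)             # -> [2, 16, 256]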
380 |
381 | class MultiHeadCrossAttentionWithRoPE(nn.Module):
382 | def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout=0.0):
383 | super().__init__()
384 | self.d_model = d_model
385 | self.n_heads = n_heads
386 | self.head_dim = d_model // n_heads
387 |
388 | self.q_proj = nn.Linear(d_model, d_model)
389 | self.k_proj = nn.Linear(d_model, d_model)
390 | self.v_proj = nn.Linear(d_model, d_model)
391 | self.out_proj = nn.Linear(d_model, d_model)
392 | self.rotary = RotaryPositionalEmbedding(self.head_dim)
393 | self.attn_dropout_p = attn_dropout_p
394 | self.resid_dropout = nn.Dropout(resid_dropout)
395 |
396 | def forward(self, query, key, value, key_padding_mask=None):
397 | batch_size, q_len, _ = query.shape
398 | _, seq_len, _ = key.shape
399 |
400 | q = self.q_proj(query).view(batch_size, q_len, self.n_heads, self.head_dim).transpose(1, 2)
401 | k = self.k_proj(key).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
402 | v = self.v_proj(value).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
403 |
404 | q, k = self.rotary(q, k)
405 |
406 | if key_padding_mask is not None:
407 | attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
408 | attn_mask = attn_mask.expand(-1, self.n_heads, q_len, -1)
409 | else:
410 | attn_mask = None
411 |
412 |         is_causal_flag = self.training  # causal mask only while training, where q_len == k_len
413 |
414 | attn_output = scaled_dot_product_attention(
415 | q, k, v,
416 | attn_mask=attn_mask,
417 | dropout_p=self.attn_dropout_p,
418 | is_causal=is_causal_flag,
419 | training=self.training
420 | )
421 |
422 | attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
423 | return self.resid_dropout(self.out_proj(attn_output))
424 |
425 |
426 | class HierarchicalEmbedding(nn.Module):
427 | def __init__(self, s1_bits, s2_bits, d_model=256):
428 | super().__init__()
429 | self.s1_bits = s1_bits
430 | self.s2_bits = s2_bits
431 |
432 | vocab_s1 = 2 ** s1_bits
433 | vocab_s2 = 2 ** s2_bits
434 |
435 | self.emb_s1 = nn.Embedding(vocab_s1, d_model)
436 | self.emb_s2 = nn.Embedding(vocab_s2, d_model)
437 | self.d_model = d_model
438 | self.fusion_proj = nn.Linear(d_model * 2, d_model)
439 |
440 | nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5)
441 | nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5)
442 |
443 | def forward(self, token_ids):
444 |         """Inputs:
445 |             token_ids: [batch_size, seq_len] packed token IDs, or an (s1_ids, s2_ids) pair
446 |         Output: [batch_size, seq_len, d_model]
447 |         """
448 |         if isinstance(token_ids, (tuple, list)):
449 | s1_ids, s2_ids = token_ids
450 | else:
451 |             s1_ids, s2_ids = self.split_token(token_ids, self.s2_bits)  # split_token is provided elsewhere in the codebase
452 | s1_emb = self.emb_s1(s1_ids) * math.sqrt(self.d_model)
453 | s2_emb = self.emb_s2(s2_ids) * math.sqrt(self.d_model)
454 | return self.fusion_proj(torch.cat([s1_emb, s2_emb], dim=-1))
455 |
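# Usage sketch (illustrative; bit-widths are assumptions): each token carries a
# coarse (s1) and a fine (s2) id; both embeddings are scaled by sqrt(d_model)
# and fused with a linear projection.
#
#   emb = HierarchicalEmbedding(s1_bits=4, s2_bits=4, d_model=256)
#   s1_ids = torch.randint(0, 16, (2, 16))
#   s2_ids = torch.randint(0, 16, (2, 16))
#   h = emb((s1_ids, s2_ids))                     # -> [2, 16, 256]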
456 |
457 | class DependencyAwareLayer(nn.Module):
458 | def __init__(self, d_model, n_heads=4, attn_dropout_p=0.0, resid_dropout=0.0):
459 | super().__init__()
460 | self.cross_attn = MultiHeadCrossAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout)
461 | self.norm = RMSNorm(d_model)
462 |
463 | def forward(self, hidden_states, sibling_embed, key_padding_mask=None):
464 |         """hidden_states: [batch, seq_len, d_model]
465 |         sibling_embed: embeddings from the sibling subtoken branch, used as the query
466 |         """
467 | attn_out = self.cross_attn(
468 | query=sibling_embed,
469 | key=hidden_states,
470 | value=hidden_states,
471 | key_padding_mask=key_padding_mask
472 | )
473 | return self.norm(hidden_states + attn_out)
474 |
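# Usage sketch (illustrative): the sibling branch's embeddings act as the query
# into the current hidden states, so the fine (s2) prediction can condition on
# the coarse (s1) token. Shapes are assumptions.
#
#   layer = DependencyAwareLayer(d_model=256, n_heads=4)
#   hidden = torch.randn(2, 16, 256)
#   sibling = torch.randn(2, 16, 256)
#   out = layer(hidden, sibling)                  # -> [2, 16, 256]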
475 |
476 | class TransformerBlock(nn.Module):
477 | def __init__(self, d_model, n_heads, ff_dim=1024, ffn_dropout_p=0.0, attn_dropout_p=0.0, resid_dropout_p=0.0):
478 | super().__init__()
479 | self.norm1 = RMSNorm(d_model)
480 | self.self_attn = MultiHeadAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout_p)
481 | self.norm2 = RMSNorm(d_model)
482 | self.ffn = FeedForward(d_model, ff_dim, ffn_dropout_p)
483 |
484 | def forward(self, x, key_padding_mask=None):
485 | residual = x
486 | x = self.norm1(x)
487 | attn_out = self.self_attn(x, key_padding_mask=key_padding_mask)
488 | x = residual + attn_out
489 |
490 | residual = x
491 | x = self.norm2(x)
492 | ffn_out = self.ffn(x)
493 | x = residual + ffn_out
494 | return x
495 |
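# Usage sketch (illustrative): a pre-norm decoder block: RMSNorm, causal
# self-attention, residual add, then RMSNorm, SwiGLU FFN, residual add.
#
#   block = TransformerBlock(d_model=256, n_heads=4, ff_dim=1024)
#   y = block(torch.randn(2, 16, 256))            # -> [2, 16, 256]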
496 |
497 | class DualHead(nn.Module):
498 | def __init__(self, s1_bits, s2_bits, d_model):
499 | super().__init__()
500 | self.vocab_s1 = 2 ** s1_bits
501 | self.vocab_s2 = 2 ** s2_bits
502 | self.proj_s1 = nn.Linear(d_model, self.vocab_s1)
503 | self.proj_s2 = nn.Linear(d_model, self.vocab_s2)
504 |
505 | def compute_loss(self, s1_logits, s2_logits, s1_targets, s2_targets, padding_mask=None):
506 | if padding_mask is not None:
507 | valid_mask = (padding_mask == 0)
508 | s1_logits = s1_logits[valid_mask]
509 | s2_logits = s2_logits[valid_mask]
510 | s1_targets = s1_targets[valid_mask]
511 | s2_targets = s2_targets[valid_mask]
512 | ce_s1 = F.cross_entropy(s1_logits, s1_targets)
513 | ce_s2 = F.cross_entropy(s2_logits, s2_targets)
514 | else:
515 | ce_s1 = F.cross_entropy(s1_logits.reshape(-1, self.vocab_s1), s1_targets.reshape(-1))
516 | ce_s2 = F.cross_entropy(s2_logits.reshape(-1, self.vocab_s2), s2_targets.reshape(-1))
517 | ce_loss = (ce_s1 + ce_s2) / 2
518 | return ce_loss, ce_s1, ce_s2
519 |
520 |     def forward(self, x):
521 |         return self.proj_s1(x)  # logits over the coarse (S1) vocabulary
522 | 
523 |     def cond_forward(self, x2):
524 |         return self.proj_s2(x2)  # logits over the fine (S2) vocabulary
525 |
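# Usage sketch (illustrative; bit-widths are assumptions): forward() scores the
# coarse (S1) vocabulary from backbone features, cond_forward() scores the fine
# (S2) vocabulary from S1-conditioned features, and compute_loss() averages the
# two cross-entropies.
#
#   head = DualHead(s1_bits=4, s2_bits=4, d_model=256)
#   s1_logits = head(torch.randn(2, 16, 256))     # -> [2, 16, 16]
#   s2_logits = head.cond_forward(torch.randn(2, 16, 256))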
526 |
527 | class FixedEmbedding(nn.Module):
528 | def __init__(self, c_in, d_model):
529 | super(FixedEmbedding, self).__init__()
530 |
531 | w = torch.zeros(c_in, d_model).float()
532 |         w.requires_grad = False
533 |
534 | position = torch.arange(0, c_in).float().unsqueeze(1)
535 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
536 |
537 | w[:, 0::2] = torch.sin(position * div_term)
538 | w[:, 1::2] = torch.cos(position * div_term)
539 |
540 | self.emb = nn.Embedding(c_in, d_model)
541 | self.emb.weight = nn.Parameter(w, requires_grad=False)
542 |
543 | def forward(self, x):
544 | return self.emb(x).detach()
545 |
546 |
547 | class TemporalEmbedding(nn.Module):
548 | def __init__(self, d_model, learn_pe):
549 | super(TemporalEmbedding, self).__init__()
550 |
551 | minute_size = 60
552 | hour_size = 24
553 | weekday_size = 7
554 | day_size = 32
555 | month_size = 13
556 |
557 | Embed = FixedEmbedding if not learn_pe else nn.Embedding
558 | self.minute_embed = Embed(minute_size, d_model)
559 | self.hour_embed = Embed(hour_size, d_model)
560 | self.weekday_embed = Embed(weekday_size, d_model)
561 | self.day_embed = Embed(day_size, d_model)
562 | self.month_embed = Embed(month_size, d_model)
563 |
564 | def forward(self, x):
565 | x = x.long()
566 |
567 | minute_x = self.minute_embed(x[:, :, 0])
568 | hour_x = self.hour_embed(x[:, :, 1])
569 | weekday_x = self.weekday_embed(x[:, :, 2])
570 | day_x = self.day_embed(x[:, :, 3])
571 | month_x = self.month_embed(x[:, :, 4])
572 |
573 | return hour_x + weekday_x + day_x + month_x + minute_x
574 |
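# Usage sketch (illustrative): x packs per-step calendar features as integer
# indices in the order [minute, hour, weekday, day, month]; the five embeddings
# are summed. With learn_pe=False the tables are fixed sinusoidal codes.
#
#   temb = TemporalEmbedding(d_model=256, learn_pe=False)
#   stamps = torch.zeros(2, 16, 5, dtype=torch.long)
#   t = temb(stamps)                              # -> [2, 16, 256]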
575 |
--------------------------------------------------------------------------------