├── img └── logo.png ├── prediction_chart.png ├── model ├── __pycache__ │ ├── kronos.cpython-39.pyc │ ├── module.cpython-39.pyc │ └── __init__.cpython-39.pyc ├── __init__.py ├── kronos.py └── module.py ├── index.html ├── style.css └── update_predictions.py /img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiyu-coder/Kronos-demo/HEAD/img/logo.png -------------------------------------------------------------------------------- /prediction_chart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiyu-coder/Kronos-demo/HEAD/prediction_chart.png -------------------------------------------------------------------------------- /model/__pycache__/kronos.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiyu-coder/Kronos-demo/HEAD/model/__pycache__/kronos.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/module.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiyu-coder/Kronos-demo/HEAD/model/__pycache__/module.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiyu-coder/Kronos-demo/HEAD/model/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .kronos import KronosTokenizer, Kronos, KronosPredictor 2 | 3 | model_dict = { 4 | 'kronos_tokenizer': KronosTokenizer, 5 | 'kronos': Kronos, 6 | 'kronos_predictor': KronosPredictor 7 | 
} 8 | 9 | 10 | def get_model_class(model_name): 11 | if model_name in model_dict: 12 | return model_dict[model_name] 13 | else: 14 | print(f"Model {model_name} not found in model_dict") 15 | raise NotImplementedError 16 | 17 | 18 | -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Kronos Live Forecast | BTC/USDT 7 | 8 | 9 | 10 | 11 | 12 |
13 |
14 |

Kronos: Live Probabilistic Forecast

15 |

A Demo of "Kronos: A Foundation Model for the Language of Financial Markets"

16 | View Project on GitHub 17 |
18 |
19 | 20 |
21 | 22 |
23 |

Live BTC/USDT Forecast Dashboard

24 | 28 | 29 |
30 |
31 |

Upside Probability (Next 24h)

32 |

50.0%

33 |
The share of simulated forecast paths in which the price 24 hours from now ends above the last known closing price.
34 |
35 |
36 |

Volatility Amplification (Next 24h)

37 |

90.0%

38 |
The share of simulated forecast paths whose volatility over the next 24h exceeds the volatility observed over the most recent 24 hours.
39 |
40 |
41 | 42 |
43 |

24-Hour Probabilistic Forecast

44 |

The chart below shows the historical price (blue) and the probabilistic forecast (orange), with trading volume in the lower panel. The orange line is the mean of multiple Monte Carlo simulations, and the shaded area spans the lowest to the highest simulated price at each hour, indicating forecast uncertainty.

45 |
46 | BTC/USDT Price and Volume Forecast Chart 47 |
48 |
49 |
50 | 51 | 52 |
53 |

Methodology Overview

54 |

This demo showcases the forecasting results of Kronos, a foundation model pre-trained on the "language" of financial markets. The predictions are generated using the following process:

55 | 61 |
62 | 63 | 64 |
65 |

About the Kronos Project

66 |

Kronos is the first open-source foundation model for financial candlesticks (K-lines), trained on data from over 45 global exchanges. It is designed to serve as a unified model for diverse quantitative finance tasks.

67 | 71 |
72 |
73 | 74 | 79 | 80 | -------------------------------------------------------------------------------- /style.css: -------------------------------------------------------------------------------- 1 | /* --- Global Styles & Variables --- */ 2 | :root { 3 | --primary-color: #3498db; 4 | --secondary-color: #2c3e50; 5 | --light-color: #ecf0f1; 6 | --dark-color: #34495e; 7 | --background-color: #f4f7f6; 8 | --card-bg-color: #ffffff; 9 | --text-color: #333; 10 | --subtle-text-color: #6c757d; 11 | --border-color: #e9ecef; 12 | } 13 | 14 | body { 15 | font-family: 'Inter', 'Lato', sans-serif; 16 | line-height: 1.7; 17 | margin: 0; 18 | padding: 0; 19 | background-color: var(--background-color); 20 | color: var(--text-color); 21 | } 22 | 23 | .container { 24 | max-width: 960px; 25 | margin: auto; 26 | padding: 0 1.5rem; 27 | } 28 | 29 | h1, h2, h3 { 30 | font-family: 'Roboto Slab', serif; 31 | color: var(--dark-color); 32 | line-height: 1.3; 33 | } 34 | 35 | h1 { font-size: 2.8rem; } 36 | h2 { font-size: 2.2rem; border-bottom: 3px solid var(--border-color); padding-bottom: 0.6rem; margin-bottom: 1.5rem; } 37 | h3 { font-size: 1.4rem; color: var(--secondary-color); } 38 | 39 | p { 40 | margin-bottom: 1rem; 41 | } 42 | 43 | /* --- Header --- */ 44 | header { 45 | background: var(--secondary-color); 46 | color: var(--light-color); 47 | padding: 2.5rem 0; 48 | text-align: center; 49 | border-bottom: 5px solid var(--primary-color); 50 | } 51 | 52 | header .project-title { 53 | margin: 0 0 0.5rem 0; 54 | color: var(--light-color); 55 | font-size: 2.8rem; 56 | } 57 | 58 | header .subtitle { 59 | font-size: 1.25rem; 60 | color: #bdc3c7; 61 | margin-bottom: 2rem; 62 | font-style: italic; 63 | } 64 | 65 | /* UPDATED: Style for the GitHub link to ensure high visibility in all states */ 66 | header .repo-link, 67 | header .repo-link:link, 68 | header .repo-link:visited { 69 | color: var(--light-color); /* Ensures normal and visited links are light-colored */ 70 | background-color: 
transparent; 71 | border: 2px solid var(--primary-color); 72 | padding: 0.6rem 1.5rem; 73 | border-radius: 50px; /* Pill shape */ 74 | text-decoration: none; 75 | font-weight: 600; 76 | transition: background-color 0.3s, color 0.3s; 77 | } 78 | 79 | header .repo-link:hover, 80 | header .repo-link:active { 81 | background-color: var(--primary-color); 82 | color: var(--card-bg-color); 83 | } 84 | 85 | 86 | /* --- Main Content & Sections --- */ 87 | main { 88 | margin: 2.5rem auto; 89 | } 90 | 91 | section { 92 | background: var(--card-bg-color); 93 | margin-bottom: 2.5rem; 94 | padding: 2rem; 95 | border-radius: 12px; 96 | box-shadow: 0 4px 15px rgba(0,0,0,0.07); 97 | } 98 | 99 | /* --- Dashboard Section --- */ 100 | #live-dashboard .metadata { 101 | display: flex; 102 | justify-content: space-between; 103 | flex-wrap: wrap; 104 | font-style: italic; 105 | color: var(--subtle-text-color); 106 | margin-bottom: 2rem; 107 | padding-bottom: 1rem; 108 | border-bottom: 1px solid var(--border-color); 109 | } 110 | 111 | .metrics-container { 112 | display: flex; 113 | justify-content: space-around; 114 | gap: 1.5rem; 115 | text-align: center; 116 | margin-bottom: 2rem; 117 | } 118 | 119 | .metric-card { 120 | flex: 1; 121 | min-width: 220px; 122 | padding: 1.5rem; 123 | background-color: #f8f9fa; 124 | border-radius: 8px; 125 | border: 1px solid var(--border-color); 126 | transition: transform 0.2s, box-shadow 0.2s; 127 | } 128 | 129 | .metric-card:hover { 130 | transform: translateY(-5px); 131 | box-shadow: 0 6px 20px rgba(0,0,0,0.1); 132 | } 133 | 134 | .metric-card h3 { 135 | margin-top: 0; 136 | font-size: 1.1rem; 137 | } 138 | 139 | .metric-value { 140 | font-size: 3rem; 141 | font-weight: 700; 142 | color: var(--primary-color); 143 | margin: 0.5rem 0; 144 | } 145 | 146 | .metric-desc { 147 | font-size: 0.9rem; 148 | color: var(--subtle-text-color); 149 | } 150 | 151 | .chart-wrapper p { 152 | max-width: 80ch; 153 | } 154 | 155 | .chart-container { 156 | width: 100%; 
157 | text-align: center; 158 | margin-top: 1rem; 159 | } 160 | 161 | .chart-img { 162 | max-width: 100%; 163 | height: auto; 164 | border: 1px solid var(--border-color); 165 | border-radius: 8px; 166 | } 167 | 168 | /* --- Methodology & About Sections --- */ 169 | #methodology ul { 170 | list-style-type: disc; 171 | padding-left: 20px; 172 | } 173 | 174 | #methodology li { 175 | margin-bottom: 0.75rem; 176 | } 177 | 178 | .project-links { 179 | margin-top: 1.5rem; 180 | display: flex; 181 | gap: 1rem; 182 | } 183 | 184 | .btn { 185 | text-decoration: none; 186 | padding: 0.8rem 1.8rem; 187 | border-radius: 5px; 188 | font-weight: 600; 189 | transition: all 0.3s ease; 190 | border: none; 191 | } 192 | .btn-primary { 193 | background-color: var(--primary-color); 194 | color: white; 195 | } 196 | .btn-primary:hover { 197 | background-color: #2980b9; 198 | } 199 | .btn-secondary { 200 | background-color: var(--border-color); 201 | color: var(--dark-color); 202 | } 203 | .btn-secondary:hover { 204 | background-color: #dcdfe2; 205 | } 206 | .btn.disabled { 207 | background-color: #e0e0e0; 208 | color: #a0a0a0; 209 | cursor: not-allowed; 210 | pointer-events: none; 211 | } 212 | 213 | 214 | /* --- Footer --- */ 215 | footer { 216 | text-align: center; 217 | padding: 2rem 0; 218 | background: var(--dark-color); 219 | color: var(--light-color); 220 | font-size: 0.9rem; 221 | } 222 | 223 | /* --- Responsive Design --- */ 224 | @media (max-width: 768px) { 225 | header .project-title { font-size: 2.2rem; } 226 | h2 { font-size: 1.8rem; } 227 | .container { padding: 0 1rem; } 228 | 229 | .metrics-container { 230 | flex-direction: column; 231 | } 232 | #live-dashboard .metadata { 233 | flex-direction: column; 234 | align-items: flex-start; 235 | gap: 0.5rem; 236 | } 237 | } -------------------------------------------------------------------------------- /update_predictions.py: -------------------------------------------------------------------------------- 1 | import gc 2 | 
import os 3 | import re 4 | import subprocess 5 | import time 6 | from datetime import datetime, timezone, timedelta 7 | from pathlib import Path 8 | 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import pandas as pd 12 | import torch 13 | from binance.client import Client 14 | 15 | from model import KronosTokenizer, Kronos, KronosPredictor 16 | 17 | # --- Configuration --- 18 | Config = { 19 | "REPO_PATH": Path(__file__).parent.resolve(), 20 | "MODEL_PATH": "../Kronos_model", 21 | "SYMBOL": 'BTCUSDT', 22 | "INTERVAL": '1h', 23 | "HIST_POINTS": 360, 24 | "PRED_HORIZON": 24, 25 | "N_PREDICTIONS": 30, 26 | "VOL_WINDOW": 24, 27 | } 28 | 29 | 30 | def load_model(): 31 | """Loads the Kronos model and tokenizer.""" 32 | print("Loading Kronos model...") 33 | tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-2k", cache_dir=Config["MODEL_PATH"]) 34 | model = Kronos.from_pretrained("NeoQuasar/Kronos-mini", cache_dir=Config["MODEL_PATH"]) 35 | tokenizer.eval() 36 | model.eval() 37 | predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512) 38 | print("Model loaded successfully.") 39 | return predictor 40 | 41 | 42 | def make_prediction(df, predictor): 43 | """Generates probabilistic forecasts using the Kronos model.""" 44 | last_timestamp = df['timestamps'].max() 45 | start_new_range = last_timestamp + pd.Timedelta(hours=1) 46 | new_timestamps_index = pd.date_range( 47 | start=start_new_range, 48 | periods=Config["PRED_HORIZON"], 49 | freq='H' 50 | ) 51 | y_timestamp = pd.Series(new_timestamps_index, name='y_timestamp') 52 | x_timestamp = df['timestamps'] 53 | x_df = df[['open', 'high', 'low', 'close', 'volume', 'amount']] 54 | 55 | with torch.no_grad(): 56 | print("Making main prediction (T=1.0)...") 57 | begin_time = time.time() 58 | close_preds_main, volume_preds_main = predictor.predict( 59 | df=x_df, x_timestamp=x_timestamp, y_timestamp=y_timestamp, 60 | pred_len=Config["PRED_HORIZON"], T=1.0, top_p=0.95, 61 | 
sample_count=Config["N_PREDICTIONS"], verbose=True 62 | ) 63 | print(f"Main prediction completed in {time.time() - begin_time:.2f} seconds.") 64 | 65 | # print("Making volatility prediction (T=0.9)...") 66 | # begin_time = time.time() 67 | # close_preds_volatility, _ = predictor.predict( 68 | # df=x_df, x_timestamp=x_timestamp, y_timestamp=y_timestamp, 69 | # pred_len=Config["PRED_HORIZON"], T=0.9, top_p=0.9, 70 | # sample_count=Config["N_PREDICTIONS"], verbose=True 71 | # ) 72 | # print(f"Volatility prediction completed in {time.time() - begin_time:.2f} seconds.") 73 | close_preds_volatility = close_preds_main 74 | 75 | return close_preds_main, volume_preds_main, close_preds_volatility 76 | 77 | 78 | def fetch_binance_data(): 79 | """Fetches K-line data from the Binance public API.""" 80 | symbol, interval = Config["SYMBOL"], Config["INTERVAL"] 81 | limit = Config["HIST_POINTS"] + Config["VOL_WINDOW"] 82 | 83 | print(f"Fetching {limit} bars of {symbol} {interval} data from Binance...") 84 | client = Client() 85 | klines = client.get_klines(symbol=symbol, interval=interval, limit=limit) 86 | 87 | cols = ['open_time', 'open', 'high', 'low', 'close', 'volume', 'close_time', 88 | 'quote_asset_volume', 'number_of_trades', 'taker_buy_base_asset_volume', 89 | 'taker_buy_quote_asset_volume', 'ignore'] 90 | df = pd.DataFrame(klines, columns=cols) 91 | 92 | df = df[['open_time', 'open', 'high', 'low', 'close', 'volume', 'quote_asset_volume']] 93 | df.rename(columns={'quote_asset_volume': 'amount', 'open_time': 'timestamps'}, inplace=True) 94 | 95 | df['timestamps'] = pd.to_datetime(df['timestamps'], unit='ms') 96 | for col in ['open', 'high', 'low', 'close', 'volume', 'amount']: 97 | df[col] = pd.to_numeric(df[col]) 98 | 99 | print("Data fetched successfully.") 100 | return df 101 | 102 | 103 | def calculate_metrics(hist_df, close_preds_df, v_close_preds_df): 104 | """ 105 | Calculates upside and volatility amplification probabilities for the 24h horizon. 
106 | """ 107 | last_close = hist_df['close'].iloc[-1] 108 | 109 | # 1. Upside Probability (for the 24-hour horizon) 110 | # This is the probability that the price at the end of the horizon is higher than now. 111 | final_hour_preds = close_preds_df.iloc[-1] 112 | upside_prob = (final_hour_preds > last_close).mean() 113 | 114 | # 2. Volatility Amplification Probability (over the 24-hour horizon) 115 | hist_log_returns = np.log(hist_df['close'] / hist_df['close'].shift(1)) 116 | historical_vol = hist_log_returns.iloc[-Config["VOL_WINDOW"]:].std() 117 | 118 | amplification_count = 0 119 | for col in v_close_preds_df.columns: 120 | full_sequence = pd.concat([pd.Series([last_close]), v_close_preds_df[col]]).reset_index(drop=True) 121 | pred_log_returns = np.log(full_sequence / full_sequence.shift(1)) 122 | predicted_vol = pred_log_returns.std() 123 | if predicted_vol > historical_vol: 124 | amplification_count += 1 125 | 126 | vol_amp_prob = amplification_count / len(v_close_preds_df.columns) 127 | 128 | print(f"Upside Probability (24h): {upside_prob:.2%}, Volatility Amplification Probability: {vol_amp_prob:.2%}") 129 | return upside_prob, vol_amp_prob 130 | 131 | 132 | def create_plot(hist_df, close_preds_df, volume_preds_df): 133 | """Generates and saves a comprehensive forecast chart.""" 134 | print("Generating comprehensive forecast chart...") 135 | # plt.style.use('seaborn-v0_8-whitegrid') 136 | fig, (ax1, ax2) = plt.subplots( 137 | 2, 1, figsize=(15, 10), sharex=True, 138 | gridspec_kw={'height_ratios': [3, 1]} 139 | ) 140 | 141 | hist_time = hist_df['timestamps'] 142 | last_hist_time = hist_time.iloc[-1] 143 | pred_time = pd.to_datetime([last_hist_time + timedelta(hours=i + 1) for i in range(len(close_preds_df))]) 144 | 145 | ax1.plot(hist_time, hist_df['close'], color='royalblue', label='Historical Price', linewidth=1.5) 146 | mean_preds = close_preds_df.mean(axis=1) 147 | ax1.plot(pred_time, mean_preds, color='darkorange', linestyle='-', label='Mean Forecast') 
148 | ax1.fill_between(pred_time, close_preds_df.min(axis=1), close_preds_df.max(axis=1), color='darkorange', alpha=0.2, label='Forecast Range (Min-Max)') 149 | ax1.set_title(f'{Config["SYMBOL"]} Probabilistic Price & Volume Forecast (Next {Config["PRED_HORIZON"]} Hours)', fontsize=16, weight='bold') 150 | ax1.set_ylabel('Price (USDT)') 151 | ax1.legend() 152 | ax1.grid(True, which='both', linestyle='--', linewidth=0.5) 153 | 154 | ax2.bar(hist_time, hist_df['volume'], color='skyblue', label='Historical Volume', width=0.03) 155 | ax2.bar(pred_time, volume_preds_df.mean(axis=1), color='sandybrown', label='Mean Forecasted Volume', width=0.03) 156 | ax2.set_ylabel('Volume') 157 | ax2.set_xlabel('Time (UTC)') 158 | ax2.legend() 159 | ax2.grid(True, which='both', linestyle='--', linewidth=0.5) 160 | 161 | separator_time = hist_time.iloc[-1] + timedelta(minutes=30) 162 | for ax in [ax1, ax2]: 163 | ax.axvline(x=separator_time, color='red', linestyle='--', linewidth=1.5, label='_nolegend_') 164 | ax.tick_params(axis='x', rotation=30) 165 | 166 | fig.tight_layout() 167 | chart_path = Config["REPO_PATH"] / 'prediction_chart.png' 168 | fig.savefig(chart_path, dpi=120) 169 | plt.close(fig) 170 | print(f"Chart saved to: {chart_path}") 171 | 172 | 173 | def update_html(upside_prob, vol_amp_prob): 174 | """ 175 | Updates the index.html file with the latest metrics and timestamp. 176 | This version uses a more robust lambda function for replacement to avoid formatting errors. 
177 | """ 178 | print("Updating index.html...") 179 | html_path = Config["REPO_PATH"] / 'index.html' 180 | now_utc_str = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S') 181 | upside_prob_str = f'{upside_prob:.1%}' 182 | vol_amp_prob_str = f'{vol_amp_prob:.1%}' 183 | 184 | with open(html_path, 'r', encoding='utf-8') as f: 185 | content = f.read() 186 | 187 | # Robustly replace content using lambda functions 188 | content = re.sub( 189 | r'().*?()', 190 | lambda m: f'{m.group(1)}{now_utc_str}{m.group(2)}', 191 | content 192 | ) 193 | content = re.sub( 194 | r'(

).*?(

)', 195 | lambda m: f'{m.group(1)}{upside_prob_str}{m.group(2)}', 196 | content 197 | ) 198 | content = re.sub( 199 | r'(

).*?(

)', 200 | lambda m: f'{m.group(1)}{vol_amp_prob_str}{m.group(2)}', 201 | content 202 | ) 203 | 204 | with open(html_path, 'w', encoding='utf-8') as f: 205 | f.write(content) 206 | print("HTML file updated successfully.") 207 | 208 | 209 | def git_commit_and_push(commit_message): 210 | """Adds, commits, and pushes specified files to the Git repository.""" 211 | print("Performing Git operations...") 212 | try: 213 | os.chdir(Config["REPO_PATH"]) 214 | subprocess.run(['git', 'add', 'prediction_chart.png', 'index.html'], check=True, capture_output=True, text=True) 215 | commit_result = subprocess.run(['git', 'commit', '-m', commit_message], check=True, capture_output=True, text=True) 216 | print(commit_result.stdout) 217 | push_result = subprocess.run(['git', 'push'], check=True, capture_output=True, text=True) 218 | print(push_result.stdout) 219 | print("Git push successful.") 220 | except subprocess.CalledProcessError as e: 221 | output = e.stdout if e.stdout else e.stderr 222 | if "nothing to commit" in output or "Your branch is up to date" in output: 223 | print("No new changes to commit or push.") 224 | else: 225 | print(f"A Git error occurred:\n--- STDOUT ---\n{e.stdout}\n--- STDERR ---\n{e.stderr}") 226 | 227 | 228 | def main_task(model): 229 | """Executes one full update cycle.""" 230 | print("\n" + "=" * 60 + f"\nStarting update task at {datetime.now(timezone.utc)}\n" + "=" * 60) 231 | df_full = fetch_binance_data() 232 | df_for_model = df_full.iloc[:-1] 233 | 234 | close_preds, volume_preds, v_close_preds = make_prediction(df_for_model, model) 235 | 236 | hist_df_for_plot = df_for_model.tail(Config["HIST_POINTS"]) 237 | hist_df_for_metrics = df_for_model.tail(Config["VOL_WINDOW"]) 238 | 239 | upside_prob, vol_amp_prob = calculate_metrics(hist_df_for_metrics, close_preds, v_close_preds) 240 | create_plot(hist_df_for_plot, close_preds, volume_preds) 241 | update_html(upside_prob, vol_amp_prob) 242 | 243 | commit_message = f"Auto-update forecast for 
{datetime.now(timezone.utc):%Y-%m-%d %H:%M} UTC" 244 | git_commit_and_push(commit_message) 245 | 246 | # --- 新增的内存清理步骤 --- 247 | # 显式删除大的DataFrame对象,帮助垃圾回收器 248 | del df_full, df_for_model, close_preds, volume_preds, v_close_preds 249 | del hist_df_for_plot, hist_df_for_metrics 250 | 251 | # 强制执行垃圾回收 252 | gc.collect() 253 | # --- 内存清理结束 --- 254 | 255 | print("-" * 60 + "\n--- Task completed successfully ---\n" + "-" * 60 + "\n") 256 | 257 | 258 | def run_scheduler(model): 259 | """A continuous scheduler that runs the main task hourly.""" 260 | while True: 261 | now = datetime.now(timezone.utc) 262 | next_run_time = (now + timedelta(hours=1)).replace(minute=0, second=5, microsecond=0) 263 | sleep_seconds = (next_run_time - now).total_seconds() 264 | 265 | if sleep_seconds > 0: 266 | print(f"Current time: {now:%Y-%m-%d %H:%M:%S UTC}.") 267 | print(f"Next run at: {next_run_time:%Y-%m-%d %H:%M:%S UTC}. Waiting for {sleep_seconds:.0f} seconds...") 268 | time.sleep(sleep_seconds) 269 | 270 | try: 271 | main_task(model) 272 | except Exception as e: 273 | print(f"\n!!!!!! 
A critical error occurred in the main task !!!!!!!") 274 | print(f"Error: {e}") 275 | import traceback 276 | traceback.print_exc() 277 | print("Retrying in 5 minutes...") 278 | print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n") 279 | time.sleep(300) 280 | 281 | 282 | if __name__ == '__main__': 283 | model_path = Path(Config["MODEL_PATH"]) 284 | model_path.mkdir(parents=True, exist_ok=True) 285 | 286 | loaded_model = load_model() 287 | main_task(loaded_model) # Run once on startup 288 | run_scheduler(loaded_model) # Start the schedule -------------------------------------------------------------------------------- /model/kronos.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | from huggingface_hub import PyTorchModelHubMixin 5 | import sys 6 | 7 | from tqdm import trange 8 | 9 | sys.path.append("../") 10 | from model.module import * 11 | 12 | 13 | class KronosTokenizer(nn.Module, PyTorchModelHubMixin): 14 | """ 15 | KronosTokenizer module for tokenizing input data using a hybrid quantization approach. 16 | 17 | This tokenizer utilizes a combination of encoder and decoder Transformer blocks 18 | along with the Binary Spherical Quantization (BSQuantizer) to compress and decompress input data. 19 | 20 | Args: 21 | d_in (int): Input dimension. 22 | d_model (int): Model dimension. 23 | n_heads (int): Number of attention heads. 24 | ff_dim (int): Feed-forward dimension. 25 | n_enc_layers (int): Number of encoder layers. 26 | n_dec_layers (int): Number of decoder layers. 27 | ffn_dropout_p (float): Dropout probability for feed-forward networks. 28 | attn_dropout_p (float): Dropout probability for attention mechanisms. 29 | resid_dropout_p (float): Dropout probability for residual connections. 30 | s1_bits (int): Number of bits for the pre token in BSQuantizer. 31 | s2_bits (int): Number of bits for the post token in BSQuantizer. 
32 | beta (float): Beta parameter for BSQuantizer. 33 | gamma0 (float): Gamma0 parameter for BSQuantizer. 34 | gamma (float): Gamma parameter for BSQuantizer. 35 | zeta (float): Zeta parameter for BSQuantizer. 36 | group_size (int): Group size parameter for BSQuantizer. 37 | 38 | """ 39 | 40 | def __init__(self, d_in, d_model, n_heads, ff_dim, n_enc_layers, n_dec_layers, ffn_dropout_p, attn_dropout_p, resid_dropout_p, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size): 41 | 42 | super().__init__() 43 | self.d_in = d_in 44 | self.d_model = d_model 45 | self.n_heads = n_heads 46 | self.ff_dim = ff_dim 47 | self.enc_layers = n_enc_layers 48 | self.dec_layers = n_dec_layers 49 | self.ffn_dropout_p = ffn_dropout_p 50 | self.attn_dropout_p = attn_dropout_p 51 | self.resid_dropout_p = resid_dropout_p 52 | 53 | self.s1_bits = s1_bits 54 | self.s2_bits = s2_bits 55 | self.codebook_dim = s1_bits + s2_bits # Total dimension of the codebook after quantization 56 | self.embed = nn.Linear(self.d_in, self.d_model) 57 | self.head = nn.Linear(self.d_model, self.d_in) 58 | 59 | # Encoder Transformer Blocks 60 | self.encoder = nn.ModuleList([ 61 | TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) 62 | for _ in range(self.enc_layers - 1) 63 | ]) 64 | # Decoder Transformer Blocks 65 | self.decoder = nn.ModuleList([ 66 | TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) 67 | for _ in range(self.dec_layers - 1) 68 | ]) 69 | self.quant_embed = nn.Linear(in_features=self.d_model, out_features=self.codebook_dim) # Linear layer before quantization 70 | self.post_quant_embed_pre = nn.Linear(in_features=self.s1_bits, out_features=self.d_model) # Linear layer after quantization (pre part - s1 bits) 71 | self.post_quant_embed = nn.Linear(in_features=self.codebook_dim, out_features=self.d_model) # Linear layer after quantization (full codebook) 72 | 
self.tokenizer = BSQuantizer(self.s1_bits, self.s2_bits, beta, gamma0, gamma, zeta, group_size) # BSQuantizer module 73 | 74 | def forward(self, x): 75 | """ 76 | Forward pass of the KronosTokenizer. 77 | 78 | Args: 79 | x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in). 80 | 81 | Returns: 82 | tuple: A tuple containing: 83 | - tuple: (z_pre, z) - Reconstructed outputs from decoder with s1_bits and full codebook respectively, 84 | both of shape (batch_size, seq_len, d_in). 85 | - torch.Tensor: bsq_loss - Loss from the BSQuantizer. 86 | - torch.Tensor: quantized - Quantized representation from BSQuantizer. 87 | - torch.Tensor: z_indices - Indices from the BSQuantizer. 88 | """ 89 | z = self.embed(x) 90 | 91 | for layer in self.encoder: 92 | z = layer(z) 93 | 94 | z = self.quant_embed(z) # (B, T, codebook) 95 | 96 | bsq_loss, quantized, z_indices = self.tokenizer(z) 97 | 98 | quantized_pre = quantized[:, :, :self.s1_bits] # Extract the first part of quantized representation (s1_bits) 99 | z_pre = self.post_quant_embed_pre(quantized_pre) 100 | 101 | z = self.post_quant_embed(quantized) 102 | 103 | # Decoder layers (for pre part - s1 bits) 104 | for layer in self.decoder: 105 | z_pre = layer(z_pre) 106 | z_pre = self.head(z_pre) 107 | 108 | # Decoder layers (for full codebook) 109 | for layer in self.decoder: 110 | z = layer(z) 111 | z = self.head(z) 112 | 113 | return (z_pre, z), bsq_loss, quantized, z_indices 114 | 115 | def indices_to_bits(self, x, half=False): 116 | """ 117 | Converts indices to bit representations and scales them. 118 | 119 | Args: 120 | x (torch.Tensor): Indices tensor. 121 | half (bool, optional): Whether to process only half of the codebook dimension. Defaults to False. 122 | 123 | Returns: 124 | torch.Tensor: Bit representation tensor. 
125 | """ 126 | if half: 127 | x1 = x[0] # Assuming x is a tuple of indices if half is True 128 | x2 = x[1] 129 | mask = 2 ** torch.arange(self.codebook_dim//2, device=x1.device, dtype=torch.long) # Create a mask for bit extraction 130 | x1 = (x1.unsqueeze(-1) & mask) != 0 # Extract bits for the first half 131 | x2 = (x2.unsqueeze(-1) & mask) != 0 # Extract bits for the second half 132 | x = torch.cat([x1, x2], dim=-1) # Concatenate the bit representations 133 | else: 134 | mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long) # Create a mask for bit extraction 135 | x = (x.unsqueeze(-1) & mask) != 0 # Extract bits 136 | 137 | x = x.float() * 2 - 1 # Convert boolean to bipolar (-1, 1) 138 | q_scale = 1. / (self.codebook_dim ** 0.5) # Scaling factor 139 | x = x * q_scale 140 | return x 141 | 142 | def encode(self, x, half=False): 143 | """ 144 | Encodes the input data into quantized indices. 145 | 146 | Args: 147 | x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in). 148 | half (bool, optional): Whether to use half quantization in BSQuantizer. Defaults to False. 149 | 150 | Returns: 151 | torch.Tensor: Quantized indices from BSQuantizer. 152 | """ 153 | z = self.embed(x) 154 | for layer in self.encoder: 155 | z = layer(z) 156 | z = self.quant_embed(z) 157 | 158 | bsq_loss, quantized, z_indices = self.tokenizer(z, half) 159 | return z_indices 160 | 161 | def decode(self, x, half=False): 162 | """ 163 | Decodes quantized indices back to the input data space. 164 | 165 | Args: 166 | x (torch.Tensor): Quantized indices tensor. 167 | half (bool, optional): Whether the indices were generated with half quantization. Defaults to False. 168 | 169 | Returns: 170 | torch.Tensor: Reconstructed output tensor of shape (batch_size, seq_len, d_in). 
171 | """ 172 | quantized = self.indices_to_bits(x, half) 173 | z = self.post_quant_embed(quantized) 174 | for layer in self.decoder: 175 | z = layer(z) 176 | z = self.head(z) 177 | return z 178 | 179 | 180 | class Kronos(nn.Module, PyTorchModelHubMixin): 181 | """ 182 | Kronos Model. 183 | 184 | Args: 185 | s1_bits (int): Number of bits for pre tokens. 186 | s2_bits (int): Number of bits for post tokens. 187 | n_layers (int): Number of Transformer blocks. 188 | d_model (int): Dimension of the model's embeddings and hidden states. 189 | n_heads (int): Number of attention heads in the MultiheadAttention layers. 190 | ff_dim (int): Dimension of the feedforward network in the Transformer blocks. 191 | ffn_dropout_p (float): Dropout probability for the feedforward network. 192 | attn_dropout_p (float): Dropout probability for the attention layers. 193 | resid_dropout_p (float): Dropout probability for residual connections. 194 | token_dropout_p (float): Dropout probability for token embeddings. 195 | learn_te (bool): Whether to use learnable temporal embeddings. 
196 | """ 197 | 198 | def __init__(self, s1_bits, s2_bits, n_layers, d_model, n_heads, ff_dim, ffn_dropout_p, attn_dropout_p, resid_dropout_p, token_dropout_p, learn_te): 199 | super().__init__() 200 | self.s1_bits = s1_bits 201 | self.s2_bits = s2_bits 202 | self.n_layers = n_layers 203 | self.d_model = d_model 204 | self.n_heads = n_heads 205 | self.learn_te = learn_te 206 | self.ff_dim = ff_dim 207 | self.ffn_dropout_p = ffn_dropout_p 208 | self.attn_dropout_p = attn_dropout_p 209 | self.resid_dropout_p = resid_dropout_p 210 | self.token_dropout_p = token_dropout_p 211 | 212 | self.s1_vocab_size = 2 ** self.s1_bits 213 | self.token_drop = nn.Dropout(self.token_dropout_p) 214 | self.embedding = HierarchicalEmbedding(self.s1_bits, self.s2_bits, self.d_model) 215 | self.time_emb = TemporalEmbedding(self.d_model, self.learn_te) 216 | self.transformer = nn.ModuleList([ 217 | TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) 218 | for _ in range(self.n_layers) 219 | ]) 220 | self.norm = RMSNorm(self.d_model) 221 | self.dep_layer = DependencyAwareLayer(self.d_model) 222 | self.head = DualHead(self.s1_bits, self.s2_bits, self.d_model) 223 | self.apply(self._init_weights) 224 | 225 | def _init_weights(self, module): 226 | 227 | if isinstance(module, nn.Linear): 228 | nn.init.xavier_normal_(module.weight) 229 | if module.bias is not None: 230 | nn.init.zeros_(module.bias) 231 | elif isinstance(module, nn.Embedding): 232 | nn.init.normal_(module.weight, mean=0, std=self.embedding.d_model ** -0.5) 233 | elif isinstance(module, nn.LayerNorm): 234 | nn.init.ones_(module.weight) 235 | nn.init.zeros_(module.bias) 236 | elif isinstance(module, RMSNorm): 237 | nn.init.ones_(module.weight) 238 | 239 | def forward(self, s1_ids, s2_ids, stamp=None, padding_mask=None, use_teacher_forcing=False, s1_targets=None): 240 | """ 241 | Args: 242 | s1_ids (torch.Tensor): Input tensor of s1 token IDs. 
Shape: [batch_size, seq_len] 243 | s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len] 244 | stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None. 245 | padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None. 246 | use_teacher_forcing (bool, optional): Whether to use teacher forcing for s1 decoding. Defaults to False. 247 | s1_targets (torch.Tensor, optional): Target s1 token IDs for teacher forcing. Shape: [batch_size, seq_len]. Defaults to None. 248 | 249 | Returns: 250 | Tuple[torch.Tensor, torch.Tensor]: 251 | - s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size] 252 | - s2_logits: Logits for s2 token predictions, conditioned on s1. Shape: [batch_size, seq_len, s2_vocab_size] 253 | """ 254 | x = self.embedding([s1_ids, s2_ids]) 255 | if stamp is not None: 256 | time_embedding = self.time_emb(stamp) 257 | x = x + time_embedding 258 | x = self.token_drop(x) 259 | 260 | for layer in self.transformer: 261 | x = layer(x, key_padding_mask=padding_mask) 262 | 263 | x = self.norm(x) 264 | 265 | s1_logits = self.head(x) 266 | 267 | if use_teacher_forcing: 268 | sibling_embed = self.embedding.emb_s1(s1_targets) 269 | else: 270 | s1_probs = F.softmax(s1_logits.detach(), dim=-1) 271 | sample_s1_ids = torch.multinomial(s1_probs.view(-1, self.s1_vocab_size), 1).view(s1_ids.shape) 272 | sibling_embed = self.embedding.emb_s1(sample_s1_ids) 273 | 274 | x2 = self.dep_layer(x, sibling_embed, key_padding_mask=padding_mask) # Dependency Aware Layer: Condition on s1 embeddings 275 | s2_logits = self.head.cond_forward(x2) 276 | return s1_logits, s2_logits 277 | 278 | def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None): 279 | """ 280 | Decodes only the s1 tokens. 281 | 282 | This method performs a forward pass to predict only s1 tokens. 
It returns the s1 logits 283 | and the context representation from the Transformer, which can be used for subsequent s2 decoding. 284 | 285 | Args: 286 | s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len] 287 | s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len] 288 | stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None. 289 | padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None. 290 | 291 | Returns: 292 | Tuple[torch.Tensor, torch.Tensor]: 293 | - s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size] 294 | - context: Context representation from the Transformer. Shape: [batch_size, seq_len, d_model] 295 | """ 296 | x = self.embedding([s1_ids, s2_ids]) 297 | if stamp is not None: 298 | time_embedding = self.time_emb(stamp) 299 | x = x + time_embedding 300 | x = self.token_drop(x) 301 | 302 | for layer in self.transformer: 303 | x = layer(x, key_padding_mask=padding_mask) 304 | 305 | x = self.norm(x) 306 | 307 | s1_logits = self.head(x) 308 | return s1_logits, x 309 | 310 | def decode_s2(self, context, s1_ids, padding_mask=None): 311 | """ 312 | Decodes the s2 tokens, conditioned on the context and s1 tokens. 313 | 314 | This method decodes s2 tokens based on a pre-computed context representation (typically from `decode_s1`) 315 | and the s1 token IDs. It uses the dependency-aware layer and the conditional s2 head to predict s2 tokens. 316 | 317 | Args: 318 | context (torch.Tensor): Context representation from the transformer (output of decode_s1). 319 | Shape: [batch_size, seq_len, d_model] 320 | s1_ids (torch.torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len] 321 | padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None. 322 | 323 | Returns: 324 | torch.Tensor: s2 logits. 
Shape: [batch_size, seq_len, s2_vocab_size] 325 | """ 326 | sibling_embed = self.embedding.emb_s1(s1_ids) 327 | x2 = self.dep_layer(context, sibling_embed, key_padding_mask=padding_mask) 328 | return self.head.cond_forward(x2) 329 | 330 | 331 | def top_k_top_p_filtering( 332 | logits, 333 | top_k: int = 0, 334 | top_p: float = 1.0, 335 | filter_value: float = -float("Inf"), 336 | min_tokens_to_keep: int = 1, 337 | ): 338 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 339 | Args: 340 | logits: logits distribution shape (batch size, vocabulary size) 341 | if top_k > 0: keep only top k tokens with highest probability (top-k filtering). 342 | if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 343 | Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) 344 | Make sure we keep at least min_tokens_to_keep per batch example in the output 345 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 346 | """ 347 | if top_k > 0: 348 | top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check 349 | # Remove all tokens with a probability less than the last token of the top-k 350 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 351 | logits[indices_to_remove] = filter_value 352 | return logits 353 | 354 | if top_p < 1.0: 355 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 356 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 357 | 358 | # Remove tokens with cumulative probability above the threshold (token with 0 are kept) 359 | sorted_indices_to_remove = cumulative_probs > top_p 360 | if min_tokens_to_keep > 1: 361 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) 362 | sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 363 | # Shift the indices to the right to keep also the first token above the threshold 
364 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 365 | sorted_indices_to_remove[..., 0] = 0 366 | 367 | # scatter sorted tensors to original indexing 368 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) 369 | logits[indices_to_remove] = filter_value 370 | return logits 371 | 372 | 373 | def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_logits=True): 374 | logits = logits / temperature 375 | if top_k is not None or top_p is not None: 376 | if top_k > 0 or top_p < 1.0: 377 | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) 378 | 379 | probs = F.softmax(logits, dim=-1) 380 | 381 | if not sample_logits: 382 | _, x = top_k(probs, k=1, dim=-1) 383 | else: 384 | x = torch.multinomial(probs, num_samples=1) 385 | 386 | return x 387 | 388 | 389 | def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False): 390 | with torch.no_grad(): 391 | batch_size = x.size(0) 392 | initial_seq_len = x.size(1) 393 | x = torch.clip(x, -clip, clip) 394 | 395 | device = x.device 396 | x = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x.size(1), x.size(2)).to(device) 397 | x_stamp = x_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x_stamp.size(1), x_stamp.size(2)).to(device) 398 | y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device) 399 | 400 | x_token = tokenizer.encode(x, half=True) 401 | 402 | def get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, pred_step): 403 | 404 | if current_seq_len <= max_context - pred_step: 405 | return torch.cat([x_stamp, y_stamp[:, :pred_step, :]], dim=1) 406 | else: 407 | start_idx = max_context - pred_step 408 | return torch.cat([x_stamp[:, -start_idx:, :], y_stamp[:, :pred_step, :]], dim=1) 409 | 410 | if verbose: 411 | ran = trange 412 | 
else: 413 | ran = range 414 | for i in ran(pred_len): 415 | current_seq_len = initial_seq_len + i 416 | 417 | if current_seq_len <= max_context: 418 | input_tokens = x_token 419 | else: 420 | input_tokens = [t[:, -max_context:].contiguous() for t in x_token] 421 | 422 | current_stamp = get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, i) 423 | 424 | s1_logits, context = model.decode_s1(input_tokens[0], input_tokens[1], current_stamp) 425 | s1_logits = s1_logits[:, -1, :] 426 | sample_pre = sample_from_logits(s1_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True) 427 | 428 | s2_logits = model.decode_s2(context, sample_pre) 429 | s2_logits = s2_logits[:, -1, :] 430 | sample_post = sample_from_logits(s2_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True) 431 | 432 | x_token[0] = torch.cat([x_token[0], sample_pre], dim=1) 433 | x_token[1] = torch.cat([x_token[1], sample_post], dim=1) 434 | 435 | input_tokens = [t[:, -max_context:].contiguous() for t in x_token] 436 | z = tokenizer.decode(input_tokens, half=True) 437 | z = z.reshape(batch_size, sample_count, z.size(1), z.size(2)) 438 | preds = z.cpu().numpy() 439 | # preds = np.mean(preds, axis=1) 440 | 441 | return preds 442 | 443 | 444 | def calc_time_stamps(x_timestamp): 445 | time_df = pd.DataFrame() 446 | time_df['minute'] = x_timestamp.dt.minute 447 | time_df['hour'] = x_timestamp.dt.hour 448 | time_df['weekday'] = x_timestamp.dt.weekday 449 | time_df['day'] = x_timestamp.dt.day 450 | time_df['month'] = x_timestamp.dt.month 451 | return time_df 452 | 453 | 454 | class KronosPredictor: 455 | 456 | def __init__(self, model, tokenizer, device="cuda:0", max_context=512, clip=5): 457 | self.tokenizer = tokenizer 458 | self.model = model 459 | self.max_context = max_context 460 | self.clip = clip 461 | self.price_cols = ['open', 'high', 'low', 'close'] 462 | self.vol_col = 'volume' 463 | self.amt_vol = 'amount' 464 | self.time_cols = ['minute', 'hour', 'weekday', 'day', 'month'] 465 
| self.device = device 466 | 467 | self.tokenizer = self.tokenizer.to(self.device) 468 | self.model = self.model.to(self.device) 469 | 470 | def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose): 471 | 472 | x_tensor = torch.from_numpy(np.array(x).astype(np.float32)).to(self.device) 473 | x_stamp_tensor = torch.from_numpy(np.array(x_stamp).astype(np.float32)).to(self.device) 474 | y_stamp_tensor = torch.from_numpy(np.array(y_stamp).astype(np.float32)).to(self.device) 475 | 476 | preds = auto_regressive_inference(self.tokenizer, self.model, x_tensor, x_stamp_tensor, y_stamp_tensor, self.max_context, pred_len, 477 | self.clip, T, top_k, top_p, sample_count, verbose) 478 | preds = preds[:, :, -pred_len:, :] 479 | return preds 480 | 481 | def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True): 482 | 483 | if not isinstance(df, pd.DataFrame): 484 | raise ValueError("Input must be a pandas DataFrame.") 485 | 486 | if not all(col in df.columns for col in self.price_cols): 487 | raise ValueError(f"Price columns {self.price_cols} not found in DataFrame.") 488 | 489 | df = df.copy() 490 | if self.vol_col not in df.columns: 491 | df[self.vol_col] = 0.0 # Fill missing volume with zeros 492 | df[self.amt_vol] = 0.0 # Fill missing amount with zeros 493 | if self.amt_vol not in df.columns and self.vol_col in df.columns: 494 | df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1) 495 | 496 | if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any(): 497 | raise ValueError("Input DataFrame contains NaN values in price or volume columns.") 498 | 499 | x_time_df = calc_time_stamps(x_timestamp) 500 | y_time_df = calc_time_stamps(y_timestamp) 501 | 502 | x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32) 503 | x_stamp = x_time_df.values.astype(np.float32) 504 | y_stamp = y_time_df.values.astype(np.float32) 505 | 506 | x_mean, 
x_std = np.mean(x, axis=0), np.std(x, axis=0) 507 | 508 | x = (x - x_mean) / (x_std + 1e-5) 509 | x = np.clip(x, -self.clip, self.clip) 510 | 511 | x = x[np.newaxis, :] 512 | x_stamp = x_stamp[np.newaxis, :] 513 | y_stamp = y_stamp[np.newaxis, :] 514 | 515 | preds = self.generate(x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose) 516 | 517 | preds = preds.squeeze(0) 518 | preds = preds * (x_std[np.newaxis, :] + 1e-5) + x_mean[np.newaxis, :] 519 | 520 | close_preds = preds[:, :, 3].swapaxes(0, 1) 521 | volume_preds = preds[:, :, 4].swapaxes(0, 1) 522 | 523 | close_df = pd.DataFrame(close_preds, columns=[f"pred-{i+1}" for i in range(sample_count)], index=y_timestamp) 524 | volume_df = pd.DataFrame(volume_preds, columns=[f"pred-{i + 1}" for i in range(sample_count)], index=y_timestamp) 525 | 526 | return close_df, volume_df 527 | -------------------------------------------------------------------------------- /model/module.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from einops import rearrange, reduce 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Function 7 | import torch.nn.functional as F 8 | 9 | 10 | class DifferentiableEntropyFunction(Function): 11 | @staticmethod 12 | def forward(ctx, zq, basis, K, eps): 13 | zb = (zq + 1) / 2 14 | zi = ((zb * basis).sum(-1)).to(torch.int64) 15 | cnt = torch.scatter_reduce(torch.zeros(2 ** K, device=zq.device, dtype=zq.dtype), 16 | 0, 17 | zi.flatten(), 18 | torch.ones_like(zi.flatten()).to(zq.dtype), 19 | 'sum') 20 | prob = (cnt + eps) / (cnt + eps).sum() 21 | H = -(prob * torch.log(prob)).sum() 22 | ctx.save_for_backward(zq, zi, prob) 23 | ctx.K = K 24 | return H 25 | 26 | @staticmethod 27 | def backward(ctx, grad_output): 28 | zq, zi, prob = ctx.saved_tensors 29 | grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K 30 | reord_grad = grad_array[zi.flatten()].reshape(zi.shape) 31 | grad_input = 
reord_grad.unsqueeze(-1) * zq 32 | return grad_input, None, None, None, None 33 | 34 | 35 | def codebook_entropy(zq, basis, K, eps=1e-4): 36 | return DifferentiableEntropyFunction.apply(zq, basis, K, eps) 37 | 38 | 39 | class BinarySphericalQuantizer(nn.Module): 40 | def __init__(self, embed_dim, beta, gamma0, gamma, zeta, 41 | input_format='bchw', 42 | soft_entropy=True, group_size=9, 43 | persample_entropy_compute='analytical', 44 | cb_entropy_compute='group', 45 | l2_norm=True, 46 | inv_temperature=1): 47 | """ 48 | Paper link: https://arxiv.org/pdf/2406.07548.pdf 49 | Here we use the official implementation of the BinarySphericalQuantizer. 50 | """ 51 | super().__init__() 52 | self.embed_dim = embed_dim 53 | self.beta = beta # loss weight for commit loss 54 | self.gamma0 = gamma0 # loss weight for entropy penalty 55 | self.gamma = gamma # loss weight for entropy penalty 56 | self.zeta = zeta # loss weight for entire entropy penalty 57 | self.input_format = input_format 58 | assert self.embed_dim % group_size == 0, "embed_dim must be divisible by group_size" 59 | self.num_groups = self.embed_dim // group_size 60 | self.group_size = group_size 61 | assert persample_entropy_compute in ['group', 'analytical'], "persample_entropy_compute must be either 'group' or 'analytical'" 62 | assert cb_entropy_compute in ['group', 'nce'], "cb_entropy_compute must be either 'group' or 'nce'" 63 | self.persample_entropy_compute = persample_entropy_compute 64 | self.cb_entropy_compute = cb_entropy_compute 65 | self.l2_norm = l2_norm 66 | self.inv_temperature = inv_temperature 67 | 68 | self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1)) 69 | self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1)) 70 | 71 | self.num_dimensions = 2 ** embed_dim 72 | self.bits_per_index = embed_dim 73 | 74 | # we only need to keep the codebook portion up to the group size 75 | # because we approximate the H loss with this subcode 76 | group_codes = 
torch.arange(2 ** self.group_size) 77 | group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:] 78 | self.register_buffer('group_codebook', group_codebook, persistent=False) 79 | 80 | self.soft_entropy = soft_entropy # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf 81 | 82 | def quantize(self, z): 83 | assert z.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}" 84 | 85 | zhat = torch.where(z > 0, 86 | torch.tensor(1, dtype=z.dtype, device=z.device), 87 | torch.tensor(-1, dtype=z.dtype, device=z.device)) 88 | return z + (zhat - z).detach() 89 | 90 | def forward(self, z): 91 | # if self.input_format == 'bchw': 92 | # z = rearrange(z, 'b c h w -> b h w c') 93 | zq = self.quantize(z) 94 | 95 | indices = self.codes_to_indexes(zq.detach()) 96 | group_indices = self.codes_to_group_indexes(zq.detach()) 97 | if not self.training: 98 | used_codes = torch.unique(indices, return_counts=False) 99 | else: 100 | used_codes = None 101 | 102 | q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. 
103 | 104 | if self.soft_entropy: 105 | persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z) 106 | entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy 107 | else: 108 | zb_by_sample = ((zq + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32) 109 | persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample) 110 | cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim) 111 | entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy 112 | 113 | zq = zq * q_scale 114 | 115 | # commit loss 116 | commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1)) 117 | 118 | # if self.input_format == 'bchw': 119 | # zq = rearrange(zq, 'b h w c -> b c h w') 120 | 121 | return ( 122 | zq, 123 | commit_loss + self.zeta * entropy_penalty / self.inv_temperature, 124 | {"H": cb_entropy, "used_codes": used_codes, "indices": indices, "group_indices": group_indices, 125 | "avg_prob": avg_prob} 126 | ) 127 | 128 | def soft_entropy_loss(self, z): 129 | # if we divide the code in subgroups of size group_size, the codebook will be of size 2 ** group_size 130 | # the sub-code is the last group_size bits of the full code 131 | group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1) 132 | divided_z = rearrange(z, '... (g c) -> ... g c', c=self.group_size) 133 | 134 | # we calculate the distance between the divided_z and the codebook for each subgroup 135 | distance = - 2 * torch.einsum('... g c, d c ->... 
g d', divided_z, group_code_book) 136 | prob = (-distance * self.inv_temperature).softmax(dim=-1) 137 | if self.persample_entropy_compute == 'analytical': 138 | if self.l2_norm: 139 | p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature) 140 | else: 141 | p = torch.sigmoid(-4 * z * self.inv_temperature) 142 | prob = torch.stack([p, 1 - p], dim=-1) 143 | per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() 144 | else: 145 | per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() 146 | 147 | # macro average of the probability of each subgroup 148 | avg_prob = reduce(prob, '... g d ->g d', 'mean') 149 | codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False) 150 | 151 | # the approximation of the entropy is the sum of the entropy of each subgroup 152 | return per_sample_entropy, codebook_entropy.sum(), avg_prob 153 | 154 | def get_hard_per_sample_entropy(self, zb_by_sample): 155 | probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1] 156 | persample_entropy = - probs_per_dim * torch.log(probs_per_dim + 1e-8) - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8) 157 | persample_entropy = persample_entropy.sum(-1) 158 | return persample_entropy.mean() 159 | 160 | def codes_to_indexes(self, zhat): 161 | """Converts a `code` to an index in the codebook. 162 | Args: 163 | zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1} 164 | """ 165 | assert zhat.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}" 166 | return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64) 167 | 168 | def codes_to_group_indexes(self, zhat): 169 | """Converts a `code` to a list of indexes (in groups) in the codebook. 170 | Args: 171 | zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1} 172 | """ 173 | zhat_in_group = rearrange(zhat, 'b ... (g c) -> b ... 
g c', c=self.group_size) 174 | return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64) 175 | 176 | def indexes_to_codes(self, indices): 177 | """Inverse of `indexes_to_codes`.""" 178 | indices = indices.unsqueeze(-1) 179 | codes_non_centered = torch.remainder( 180 | torch.floor_divide(indices, self.basis), 2 181 | ) 182 | return codes_non_centered * 2 - 1 183 | 184 | def group_indexes_to_codes(self, group_indices): 185 | """Inverse of `group_indexes_to_codes`.""" 186 | group_indices = group_indices.unsqueeze(-1) 187 | codes_non_centered = torch.remainder( 188 | torch.floor_divide(group_indices, self.group_basis), 2 189 | ) 190 | codes_non_centered = rearrange(codes_non_centered, 'b ... g c -> b ... (g c)') 191 | return codes_non_centered * 2 - 1 192 | 193 | def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True): 194 | if normalize: 195 | probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True) 196 | else: 197 | probs = count 198 | H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim) 199 | return H 200 | 201 | def get_group_codebook_entry(self, group_indices): 202 | z_q = self.group_indexes_to_codes(group_indices) 203 | q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. 204 | z_q = z_q * q_scale 205 | if self.input_format == 'bchw': 206 | h, w = int(z_q.shape[1] ** 0.5) 207 | assert h * w == z_q.shape[1], 'Invalid sequence length' 208 | z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h) 209 | return z_q 210 | 211 | def get_codebook_entry(self, indices): 212 | z_q = self.indexes_to_codes(indices) 213 | q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. 
214 | z_q = z_q * q_scale 215 | if self.input_format == 'bchw': 216 | h, w = int(z_q.shape[1] ** 0.5) 217 | assert h * w == z_q.shape[1], 'Invalid sequence length' 218 | z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h) 219 | return z_q 220 | 221 | 222 | class BSQuantizer(nn.Module): 223 | 224 | def __init__(self, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size): 225 | super().__init__() 226 | self.codebook_dim = s1_bits + s2_bits 227 | self.s1_bits = s1_bits 228 | self.s2_bits = s2_bits 229 | self.bsq = BinarySphericalQuantizer(self.codebook_dim, beta, gamma0, gamma, zeta, group_size=group_size) 230 | 231 | def bits_to_indices(self, bits): 232 | bits = (bits >= 0).to(torch.long) 233 | indices = 2 ** torch.arange( 234 | 0, 235 | bits.shape[-1], 236 | 1, 237 | dtype=torch.long, 238 | device=bits.device, 239 | ) 240 | return (bits * indices).sum(-1) 241 | 242 | def forward(self, z, half=False): 243 | z = F.normalize(z, dim=-1) 244 | quantized, bsq_loss, metrics = self.bsq(z) 245 | if half: 246 | q_pre = quantized[:, :, :self.s1_bits] 247 | q_post = quantized[:, :, self.s1_bits:] 248 | z_indices = [self.bits_to_indices(q_pre), self.bits_to_indices(q_post)] 249 | else: 250 | z_indices = self.bits_to_indices(quantized) 251 | return bsq_loss, quantized, z_indices 252 | 253 | 254 | class RMSNorm(torch.nn.Module): 255 | def __init__(self, dim: int, eps: float = 1e-5): 256 | super().__init__() 257 | self.eps = eps 258 | self.weight = nn.Parameter(torch.ones(dim)) 259 | 260 | def _norm(self, x): 261 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) 262 | 263 | def forward(self, x): 264 | output = self._norm(x.float()).type_as(x) 265 | return output * self.weight 266 | 267 | 268 | class FeedForward(nn.Module): 269 | def __init__(self, d_model, ff_dim, ffn_dropout_p=0.0): 270 | super().__init__() 271 | 272 | self.w1 = nn.Linear(d_model, ff_dim, bias=False) 273 | self.w3 = nn.Linear(d_model, ff_dim, bias=False) 274 | self.w2 = 
nn.Linear(ff_dim, d_model, bias=False) 275 | self.ffn_dropout = nn.Dropout(ffn_dropout_p) 276 | 277 | def forward(self, x): 278 | return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) 279 | 280 | 281 | class RotaryPositionalEmbedding(nn.Module): 282 | def __init__(self, dim): 283 | super().__init__() 284 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 285 | self.register_buffer("inv_freq", inv_freq) 286 | self.seq_len_cached = None 287 | self.cos_cached = None 288 | self.sin_cached = None 289 | 290 | def _update_cos_sin_cache(self, x, seq_len): 291 | if seq_len != self.seq_len_cached: 292 | self.seq_len_cached = seq_len 293 | t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) 294 | freqs = torch.einsum('i,j->ij', t, self.inv_freq) 295 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 296 | self.cos_cached = emb.cos()[None, None, :, :] 297 | self.sin_cached = emb.sin()[None, None, :, :] 298 | return self.cos_cached, self.sin_cached 299 | 300 | def forward(self, q, k): 301 | cos, sin = self._update_cos_sin_cache(q, q.shape[-2]) 302 | return ( 303 | (q * cos) + (self._rotate_half(q) * sin), 304 | (k * cos) + (self._rotate_half(k) * sin), 305 | ) 306 | 307 | def _rotate_half(self, x): 308 | x1, x2 = x.chunk(2, dim=-1) 309 | return torch.cat((-x2, x1), dim=-1) 310 | 311 | 312 | def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, training=True) -> torch.Tensor: 313 | L, S = query.size(-2), key.size(-2) 314 | scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale 315 | attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.device) 316 | 317 | if is_causal: 318 | assert attn_mask is None 319 | temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).to(query.device) 320 | attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) 321 | attn_bias.to(query.dtype) 322 | 323 | attn_weight = query @ key.transpose(-2, -1) * 
scale_factor 324 | attn_weight += attn_bias 325 | 326 | if attn_mask is not None: 327 | attn_mask_bias = torch.zeros_like(attn_weight) 328 | if attn_mask.dtype == torch.bool: 329 | attn_mask_bias.masked_fill_(attn_mask, float("-inf")) 330 | else: 331 | attn_mask_bias += attn_mask 332 | attn_weight += attn_mask_bias 333 | 334 | attn_weight = torch.softmax(attn_weight, dim=-1) 335 | attn_weight = torch.dropout(attn_weight, dropout_p, train=training) 336 | return attn_weight @ value 337 | 338 | 339 | class MultiHeadAttentionWithRoPE(nn.Module): 340 | def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout_p=0.0): 341 | super().__init__() 342 | self.d_model = d_model 343 | self.n_heads = n_heads 344 | self.head_dim = d_model // n_heads 345 | 346 | self.q_proj = nn.Linear(d_model, d_model) 347 | self.k_proj = nn.Linear(d_model, d_model) 348 | self.v_proj = nn.Linear(d_model, d_model) 349 | self.out_proj = nn.Linear(d_model, d_model) 350 | self.rotary = RotaryPositionalEmbedding(self.head_dim) 351 | self.attn_dropout_p = attn_dropout_p 352 | self.resid_dropout = nn.Dropout(resid_dropout_p) 353 | 354 | def forward(self, x, key_padding_mask=None): 355 | batch_size, seq_len, _ = x.shape 356 | 357 | q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) 358 | k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) 359 | v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) 360 | 361 | q, k = self.rotary(q, k) 362 | 363 | if key_padding_mask is not None: 364 | attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) # [batch, 1, 1, seq_len] 365 | attn_mask = attn_mask.expand(-1, self.n_heads, seq_len, -1) # [batch, n_heads, q_len, k_len] 366 | else: 367 | attn_mask = None 368 | 369 | attn_output = scaled_dot_product_attention( 370 | q, k, v, 371 | attn_mask=attn_mask, 372 | dropout_p=self.attn_dropout_p, 373 | is_causal=True, 374 | training=self.training 375 | 
) 376 | 377 | attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model) 378 | return self.resid_dropout(self.out_proj(attn_output)) 379 | 380 | 381 | class MultiHeadCrossAttentionWithRoPE(nn.Module): 382 | def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout=0.0): 383 | super().__init__() 384 | self.d_model = d_model 385 | self.n_heads = n_heads 386 | self.head_dim = d_model // n_heads 387 | 388 | self.q_proj = nn.Linear(d_model, d_model) 389 | self.k_proj = nn.Linear(d_model, d_model) 390 | self.v_proj = nn.Linear(d_model, d_model) 391 | self.out_proj = nn.Linear(d_model, d_model) 392 | self.rotary = RotaryPositionalEmbedding(self.head_dim) 393 | self.attn_dropout_p = attn_dropout_p 394 | self.resid_dropout = nn.Dropout(resid_dropout) 395 | 396 | def forward(self, query, key, value, key_padding_mask=None): 397 | batch_size, q_len, _ = query.shape 398 | _, seq_len, _ = key.shape 399 | 400 | q = self.q_proj(query).view(batch_size, q_len, self.n_heads, self.head_dim).transpose(1, 2) 401 | k = self.k_proj(key).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) 402 | v = self.v_proj(value).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) 403 | 404 | q, k = self.rotary(q, k) 405 | 406 | if key_padding_mask is not None: 407 | attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) 408 | attn_mask = attn_mask.expand(-1, self.n_heads, q_len, -1) 409 | else: 410 | attn_mask = None 411 | 412 | is_causal_flag = self.training 413 | 414 | attn_output = scaled_dot_product_attention( 415 | q, k, v, 416 | attn_mask=attn_mask, 417 | dropout_p=self.attn_dropout_p, 418 | is_causal=is_causal_flag, 419 | training=self.training 420 | ) 421 | 422 | attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model) 423 | return self.resid_dropout(self.out_proj(attn_output)) 424 | 425 | 426 | class HierarchicalEmbedding(nn.Module): 427 | def __init__(self, s1_bits, 
s2_bits, d_model=256): 428 | super().__init__() 429 | self.s1_bits = s1_bits 430 | self.s2_bits = s2_bits 431 | 432 | vocab_s1 = 2 ** s1_bits 433 | vocab_s2 = 2 ** s2_bits 434 | 435 | self.emb_s1 = nn.Embedding(vocab_s1, d_model) 436 | self.emb_s2 = nn.Embedding(vocab_s2, d_model) 437 | self.d_model = d_model 438 | self.fusion_proj = nn.Linear(d_model * 2, d_model) 439 | 440 | nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5) 441 | nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5) 442 | 443 | def forward(self, token_ids): 444 | """Inputs: 445 | token_ids: [batch_size, seq_len] token ID 446 | Output: [batch_size, seq_len, d_model] 447 | """ 448 | if isinstance(token_ids, tuple) or isinstance(token_ids, list): 449 | s1_ids, s2_ids = token_ids 450 | else: 451 | s1_ids, s2_ids = self.split_token(token_ids, self.s2_bits) 452 | s1_emb = self.emb_s1(s1_ids) * math.sqrt(self.d_model) 453 | s2_emb = self.emb_s2(s2_ids) * math.sqrt(self.d_model) 454 | return self.fusion_proj(torch.cat([s1_emb, s2_emb], dim=-1)) 455 | 456 | 457 | class DependencyAwareLayer(nn.Module): 458 | def __init__(self, d_model, n_heads=4, attn_dropout_p=0.0, resid_dropout=0.0): 459 | super().__init__() 460 | self.cross_attn = MultiHeadCrossAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout) 461 | self.norm = RMSNorm(d_model) 462 | 463 | def forward(self, hidden_states, sibling_embed, key_padding_mask=None): 464 | """hidden_states: [batch, seq_len, d_model] 465 | sibling_embed: Embedding from another subtoken 466 | """ 467 | attn_out = self.cross_attn( 468 | query=sibling_embed, 469 | key=hidden_states, 470 | value=hidden_states, 471 | key_padding_mask=key_padding_mask 472 | ) 473 | return self.norm(hidden_states + attn_out) 474 | 475 | 476 | class TransformerBlock(nn.Module): 477 | def __init__(self, d_model, n_heads, ff_dim=1024, ffn_dropout_p=0.0, attn_dropout_p=0.0, resid_dropout_p=0.0): 478 | super().__init__() 479 | self.norm1 = RMSNorm(d_model) 480 | 
self.self_attn = MultiHeadAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout_p) 481 | self.norm2 = RMSNorm(d_model) 482 | self.ffn = FeedForward(d_model, ff_dim, ffn_dropout_p) 483 | 484 | def forward(self, x, key_padding_mask=None): 485 | residual = x 486 | x = self.norm1(x) 487 | attn_out = self.self_attn(x, key_padding_mask=key_padding_mask) 488 | x = residual + attn_out 489 | 490 | residual = x 491 | x = self.norm2(x) 492 | ffn_out = self.ffn(x) 493 | x = residual + ffn_out 494 | return x 495 | 496 | 497 | class DualHead(nn.Module): 498 | def __init__(self, s1_bits, s2_bits, d_model): 499 | super().__init__() 500 | self.vocab_s1 = 2 ** s1_bits 501 | self.vocab_s2 = 2 ** s2_bits 502 | self.proj_s1 = nn.Linear(d_model, self.vocab_s1) 503 | self.proj_s2 = nn.Linear(d_model, self.vocab_s2) 504 | 505 | def compute_loss(self, s1_logits, s2_logits, s1_targets, s2_targets, padding_mask=None): 506 | if padding_mask is not None: 507 | valid_mask = (padding_mask == 0) 508 | s1_logits = s1_logits[valid_mask] 509 | s2_logits = s2_logits[valid_mask] 510 | s1_targets = s1_targets[valid_mask] 511 | s2_targets = s2_targets[valid_mask] 512 | ce_s1 = F.cross_entropy(s1_logits, s1_targets) 513 | ce_s2 = F.cross_entropy(s2_logits, s2_targets) 514 | else: 515 | ce_s1 = F.cross_entropy(s1_logits.reshape(-1, self.vocab_s1), s1_targets.reshape(-1)) 516 | ce_s2 = F.cross_entropy(s2_logits.reshape(-1, self.vocab_s2), s2_targets.reshape(-1)) 517 | ce_loss = (ce_s1 + ce_s2) / 2 518 | return ce_loss, ce_s1, ce_s2 519 | 520 | def forward(self, x): 521 | return self.proj_s1(x) 522 | 523 | def cond_forward(self, x2): 524 | return self.proj_s2(x2) 525 | 526 | 527 | class FixedEmbedding(nn.Module): 528 | def __init__(self, c_in, d_model): 529 | super(FixedEmbedding, self).__init__() 530 | 531 | w = torch.zeros(c_in, d_model).float() 532 | w.require_grad = False 533 | 534 | position = torch.arange(0, c_in).float().unsqueeze(1) 535 | div_term = (torch.arange(0, d_model, 2).float() * 
-(math.log(10000.0) / d_model)).exp() 536 | 537 | w[:, 0::2] = torch.sin(position * div_term) 538 | w[:, 1::2] = torch.cos(position * div_term) 539 | 540 | self.emb = nn.Embedding(c_in, d_model) 541 | self.emb.weight = nn.Parameter(w, requires_grad=False) 542 | 543 | def forward(self, x): 544 | return self.emb(x).detach() 545 | 546 | 547 | class TemporalEmbedding(nn.Module): 548 | def __init__(self, d_model, learn_pe): 549 | super(TemporalEmbedding, self).__init__() 550 | 551 | minute_size = 60 552 | hour_size = 24 553 | weekday_size = 7 554 | day_size = 32 555 | month_size = 13 556 | 557 | Embed = FixedEmbedding if not learn_pe else nn.Embedding 558 | self.minute_embed = Embed(minute_size, d_model) 559 | self.hour_embed = Embed(hour_size, d_model) 560 | self.weekday_embed = Embed(weekday_size, d_model) 561 | self.day_embed = Embed(day_size, d_model) 562 | self.month_embed = Embed(month_size, d_model) 563 | 564 | def forward(self, x): 565 | x = x.long() 566 | 567 | minute_x = self.minute_embed(x[:, :, 0]) 568 | hour_x = self.hour_embed(x[:, :, 1]) 569 | weekday_x = self.weekday_embed(x[:, :, 2]) 570 | day_x = self.day_embed(x[:, :, 3]) 571 | month_x = self.month_embed(x[:, :, 4]) 572 | 573 | return hour_x + weekday_x + day_x + month_x + minute_x 574 | 575 | 576 | 577 | 578 | 579 | 580 | 581 | --------------------------------------------------------------------------------