├── finetune
    ├── utils
    │   ├── __init__.py
    │   └── training_utils.py
    ├── qlib_data_preprocess.py
    ├── dataset.py
    ├── config.py
    ├── train_predictor.py
    ├── train_tokenizer.py
    └── qlib_test.py
├── figures
    ├── logo.png
    ├── overview.png
    ├── prediction_example.png
    └── backtest_result_example.png
├── webui
    ├── requirements.txt
    ├── start.sh
    ├── run.py
    └── README.md
├── requirements.txt
├── finetune_csv
    ├── examples
    │   ├── HK_ali_09988_kline_5min_all_historical_20250919_073929.png
    │   ├── HK_ali_09988_kline_5min_all_historical_20250919_073944.png
    │   ├── HK_ali_09988_kline_5min_all_historical_20250919_074012.png
    │   ├── HK_ali_09988_kline_5min_all_historical_20250919_074042.png
    │   └── HK_ali_09988_kline_5min_all_historical_20250919_074251.png
    ├── configs
    │   └── config_ali09988_candle-5min.yaml
    ├── README_CN.md
    ├── README.md
    ├── config_loader.py
    ├── finetune_tokenizer.py
    ├── train_sequential.py
    └── finetune_base_model.py
├── model
    ├── __init__.py
    └── module.py
├── .gitignore
├── LICENSE
├── examples
    ├── prediction_wo_vol_example.py
    ├── prediction_batch_example.py
    └── prediction_example.py
└── README.md

/finetune/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/figures/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mon2026/Kronos/HEAD/figures/logo.png
--------------------------------------------------------------------------------
/figures/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mon2026/Kronos/HEAD/figures/overview.png
--------------------------------------------------------------------------------
/figures/prediction_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mon2026/Kronos/HEAD/figures/prediction_example.png
--------------------------------------------------------------------------------
/figures/backtest_result_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mon2026/Kronos/HEAD/figures/backtest_result_example.png
--------------------------------------------------------------------------------
/webui/requirements.txt:
--------------------------------------------------------------------------------
1 | flask==2.3.3
2 | flask-cors==4.0.0
3 | pandas==2.2.2
4 | numpy==1.24.3
5 | plotly==5.17.0
6 | torch>=2.1.0
7 | huggingface_hub==0.33.1
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | pandas
3 | torch
4 | 
5 | einops==0.8.1
6 | huggingface_hub==0.33.1
7 | matplotlib==3.9.3
8 | pandas==2.2.2
9 | tqdm==4.67.1
10 | safetensors==0.6.2
--------------------------------------------------------------------------------
/finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_073929.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mon2026/Kronos/HEAD/finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_073929.png
--------------------------------------------------------------------------------
/finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_073944.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mon2026/Kronos/HEAD/finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_073944.png -------------------------------------------------------------------------------- /finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_074012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mon2026/Kronos/HEAD/finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_074012.png -------------------------------------------------------------------------------- /finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_074042.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mon2026/Kronos/HEAD/finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_074042.png -------------------------------------------------------------------------------- /finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_074251.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mon2026/Kronos/HEAD/finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_074251.png -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .kronos import KronosTokenizer, Kronos, KronosPredictor 2 | 3 | model_dict = { 4 | 'kronos_tokenizer': KronosTokenizer, 5 | 'kronos': Kronos, 6 | 'kronos_predictor': KronosPredictor 7 | } 8 | 9 | 10 | def get_model_class(model_name): 11 | if model_name in model_dict: 12 | return model_dict[model_name] 13 | else: 14 | print(f"Model {model_name} not found in model_dict") 15 | raise NotImplementedError 16 | 17 | 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | MANIFEST 23 | 24 | # Jupyter Notebook 25 | .ipynb_checkpoints 26 | 27 | # PyCharm 28 | .idea/ 29 | 30 | # VS Code 31 | .vscode/ 32 | 33 | # macOS 34 | .DS_Store 35 | .AppleDouble 36 | .LSOverride 37 | 38 | # Windows 39 | Thumbs.db 40 | ehthumbs.db 41 | Desktop.ini 42 | 43 | # Linux 44 | *~ 45 | 46 | # Data files (large files) 47 | *.feather 48 | *.parquet 49 | *.h5 50 | *.hdf5 51 | 52 | # Model files (large files) 53 | *.pth 54 | *.pt 55 | *.ckpt 56 | *.bin 57 | 58 | # Logs 59 | *.log 60 | logs/ 61 | 62 | # Environment 63 | .env 64 | .venv 65 | env/ 66 | venv/ 67 | ENV/ 68 | env.bak/ 69 | venv.bak/ 70 | 71 | # Temporary files 72 | *.tmp 73 | *.temp 74 | temp/ 75 | tmp/ 76 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 ShiYu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without 
limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /webui/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Kronos Web UI startup script 4 | 5 | echo "🚀 Starting Kronos Web UI..." 6 | echo "================================" 7 | 8 | # Check if Python is installed 9 | if ! command -v python3 &> /dev/null; then 10 | echo "❌ Python3 not installed, please install Python3 first" 11 | exit 1 12 | fi 13 | 14 | # Check if in correct directory 15 | if [ ! -f "app.py" ]; then 16 | echo "❌ Please run this script in the webui directory" 17 | exit 1 18 | fi 19 | 20 | # Check dependencies 21 | echo "📦 Checking dependencies..." 22 | if ! python3 -c "import flask, flask_cors, pandas, numpy, plotly" &> /dev/null; then 23 | echo "⚠️ Missing dependencies, installing..." 24 | pip3 install -r requirements.txt 25 | if [ $? -ne 0 ]; then 26 | echo "❌ Dependencies installation failed" 27 | exit 1 28 | fi 29 | echo "✅ Dependencies installation completed" 30 | else 31 | echo "✅ All dependencies installed" 32 | fi 33 | 34 | # Start application 35 | echo "🌐 Starting Web server..." 36 | echo "Access URL: http://localhost:7070" 37 | echo "Press Ctrl+C to stop server" 38 | echo "" 39 | 40 | python3 app.py 41 | -------------------------------------------------------------------------------- /examples/prediction_wo_vol_example.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import sys 4 | sys.path.append("../") 5 | from model import Kronos, KronosTokenizer, KronosPredictor 6 | 7 | 8 | def plot_prediction(kline_df, pred_df): 9 | pred_df.index = kline_df.index[-pred_df.shape[0]:] 10 | sr_close = kline_df['close'] 11 | sr_pred_close = pred_df['close'] 12 | sr_close.name = 'Ground Truth' 13 | sr_pred_close.name = "Prediction" 14 | 15 | close_df = pd.concat([sr_close, sr_pred_close], axis=1) 16 | 17 | fig, ax = plt.subplots(1, 1, figsize=(8, 4)) 18 | 19 | ax.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) 20 | ax.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5) 21 | ax.set_ylabel('Close Price', fontsize=14) 22 | ax.legend(loc='lower left', fontsize=12) 23 | ax.grid(True) 24 | 25 | plt.tight_layout() 26 | plt.show() 27 | 28 | 29 | # 1. Load Model and Tokenizer 30 | tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base") 31 | model = Kronos.from_pretrained("NeoQuasar/Kronos-small") 32 | 33 | # 2. 
Instantiate Predictor 34 | predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512) 35 | 36 | # 3. Prepare Data 37 | df = pd.read_csv("./data/XSHG_5min_600977.csv") 38 | df['timestamps'] = pd.to_datetime(df['timestamps']) 39 | 40 | lookback = 400 41 | pred_len = 120 42 | 43 | x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close']] 44 | x_timestamp = df.loc[:lookback-1, 'timestamps'] 45 | y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps'] 46 | 47 | # 4. Make Prediction 48 | pred_df = predictor.predict( 49 | df=x_df, 50 | x_timestamp=x_timestamp, 51 | y_timestamp=y_timestamp, 52 | pred_len=pred_len, 53 | T=1.0, 54 | top_p=0.9, 55 | sample_count=1, 56 | verbose=True 57 | ) 58 | 59 | # 5. Visualize Results 60 | print("Forecasted Data Head:") 61 | print(pred_df.head()) 62 | 63 | # Combine historical and forecasted data for plotting 64 | kline_df = df.loc[:lookback+pred_len-1] 65 | 66 | # visualize 67 | plot_prediction(kline_df, pred_df) 68 | 69 | -------------------------------------------------------------------------------- /finetune_csv/configs/config_ali09988_candle-5min.yaml: -------------------------------------------------------------------------------- 1 | #This is a template config for custom finetuning kronos on csv data 2 | #这是一份模板config,用于kronos的csv自定义数据微调 3 | 4 | data: 5 | data_path: "/xxxx/Kronos/finetune_csv/data/HK_ali_09988_kline_5min_all.csv" 6 | lookback_window: 512 7 | predict_window: 48 8 | max_context: 512 9 | clip: 5.0 10 | # dataset split ratio 11 | train_ratio: 0.9 12 | val_ratio: 0.1 13 | test_ratio: 0.0 14 | 15 | training: 16 | # control the training epochs of tokenizer and basemodel 17 | tokenizer_epochs: 30 18 | basemodel_epochs: 20 19 | batch_size: 32 20 | log_interval: 50 21 | num_workers: 6 22 | seed: 42 23 | 24 | tokenizer_learning_rate: 0.0002 25 | predictor_learning_rate: 0.000001 26 | 27 | adam_beta1: 0.9 28 | adam_beta2: 0.95 29 | adam_weight_decay: 0.1 30 | 31 | # gradient accumulation steps for tokenizer training 32 | accumulation_steps: 1 33 | 34 | # model path configuration 35 | model_paths: 36 | # pretrained model path 37 | pretrained_tokenizer: "/xxx/Kronos/pretrained/Kronos-Tokenizer-base" 38 | pretrained_predictor: "/xxx/Kronos/pretrained/Kronos-base" 39 | 40 | # experiment name - other paths will be generated based on this 41 | exp_name: "HK_ali_09988_kline_5min_all" 42 | base_path: "/xxx/Kronos/finetune_csv/finetuned/" 43 | 44 | # the following paths will be generated based on exp_name, no need to modify manually 45 | # way 1: leave empty string, the system will generate the full path 46 | base_save_path: "" # /xxxx/Kronos/finetune_csv/finetuned/{exp_name} 47 | finetuned_tokenizer: "" # /xxxx/Kronos/finetune_csv/finetuned/{exp_name}/tokenizer/best_model 48 | 49 | # way 2: use template string, {exp_name} will be replaced with the actual experiment name 50 | # base_save_path: "/xxxx/Kronos/finetune_csv/finetuned/{exp_name}" 51 | # finetuned_tokenizer: "/xxxx/Kronos/finetune_csv/finetuned/{exp_name}/tokenizer/best_model" 52 | 53 | tokenizer_save_name: "tokenizer" 54 | basemodel_save_name: "basemodel" 55 | 56 | experiment: 57 | name: "kronos_custom_finetune" 58 | description: "Custom finetune for HK stock data" 59 | use_comet: false 60 | 61 | # control the training phase 62 | train_tokenizer: true 63 | train_basemodel: true 64 | 65 | # if true, skip the existing model training 66 | skip_existing: false 67 | 68 | # device configuration 69 | device: 70 | use_cuda: true 71 | device_id: 0 72 | 73 | 
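# Usage (see README.md / README_CN.md in this directory): once data_path and the
# pretrained model paths above are filled in, this config is passed to the training
# entry points, e.g.
#   python train_sequential.py --config configs/config_ali09988_candle-5min.yaml
# or, to run the two stages separately:
#   python finetune_tokenizer.py --config configs/config_ali09988_candle-5min.yaml
#   python finetune_base_model.py --config configs/config_ali09988_candle-5min.yaml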
-------------------------------------------------------------------------------- /examples/prediction_batch_example.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import sys 4 | sys.path.append("../") 5 | from model import Kronos, KronosTokenizer, KronosPredictor 6 | 7 | 8 | def plot_prediction(kline_df, pred_df): 9 | pred_df.index = kline_df.index[-pred_df.shape[0]:] 10 | sr_close = kline_df['close'] 11 | sr_pred_close = pred_df['close'] 12 | sr_close.name = 'Ground Truth' 13 | sr_pred_close.name = "Prediction" 14 | 15 | sr_volume = kline_df['volume'] 16 | sr_pred_volume = pred_df['volume'] 17 | sr_volume.name = 'Ground Truth' 18 | sr_pred_volume.name = "Prediction" 19 | 20 | close_df = pd.concat([sr_close, sr_pred_close], axis=1) 21 | volume_df = pd.concat([sr_volume, sr_pred_volume], axis=1) 22 | 23 | fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True) 24 | 25 | ax1.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) 26 | ax1.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5) 27 | ax1.set_ylabel('Close Price', fontsize=14) 28 | ax1.legend(loc='lower left', fontsize=12) 29 | ax1.grid(True) 30 | 31 | ax2.plot(volume_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) 32 | ax2.plot(volume_df['Prediction'], label='Prediction', color='red', linewidth=1.5) 33 | ax2.set_ylabel('Volume', fontsize=14) 34 | ax2.legend(loc='upper left', fontsize=12) 35 | ax2.grid(True) 36 | 37 | plt.tight_layout() 38 | plt.show() 39 | 40 | 41 | # 1. Load Model and Tokenizer 42 | tokenizer = KronosTokenizer.from_pretrained('/home/csc/huggingface/Kronos-Tokenizer-base/') 43 | model = Kronos.from_pretrained("/home/csc/huggingface/Kronos-base/") 44 | 45 | # 2. Instantiate Predictor 46 | predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512) 47 | 48 | # 3. 
Prepare Data 49 | df = pd.read_csv("./data/XSHG_5min_600977.csv") 50 | df['timestamps'] = pd.to_datetime(df['timestamps']) 51 | 52 | lookback = 400 53 | pred_len = 120 54 | 55 | dfs = [] 56 | xtsp = [] 57 | ytsp = [] 58 | for i in range(5): 59 | idf = df.loc[(i*400):(i*400+lookback-1), ['open', 'high', 'low', 'close', 'volume', 'amount']] 60 | i_x_timestamp = df.loc[(i*400):(i*400+lookback-1), 'timestamps'] 61 | i_y_timestamp = df.loc[(i*400+lookback):(i*400+lookback+pred_len-1), 'timestamps'] 62 | 63 | dfs.append(idf) 64 | xtsp.append(i_x_timestamp) 65 | ytsp.append(i_y_timestamp) 66 | 67 | pred_df = predictor.predict_batch( 68 | df_list=dfs, 69 | x_timestamp_list=xtsp, 70 | y_timestamp_list=ytsp, 71 | pred_len=pred_len, 72 | ) 73 | -------------------------------------------------------------------------------- /examples/prediction_example.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import sys 4 | sys.path.append("../") 5 | from model import Kronos, KronosTokenizer, KronosPredictor 6 | 7 | 8 | def plot_prediction(kline_df, pred_df): 9 | pred_df.index = kline_df.index[-pred_df.shape[0]:] 10 | sr_close = kline_df['close'] 11 | sr_pred_close = pred_df['close'] 12 | sr_close.name = 'Ground Truth' 13 | sr_pred_close.name = "Prediction" 14 | 15 | sr_volume = kline_df['volume'] 16 | sr_pred_volume = pred_df['volume'] 17 | sr_volume.name = 'Ground Truth' 18 | sr_pred_volume.name = "Prediction" 19 | 20 | close_df = pd.concat([sr_close, sr_pred_close], axis=1) 21 | volume_df = pd.concat([sr_volume, sr_pred_volume], axis=1) 22 | 23 | fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True) 24 | 25 | ax1.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) 26 | ax1.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5) 27 | ax1.set_ylabel('Close Price', fontsize=14) 28 | ax1.legend(loc='lower left', fontsize=12) 29 | ax1.grid(True) 30 | 31 | ax2.plot(volume_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) 32 | ax2.plot(volume_df['Prediction'], label='Prediction', color='red', linewidth=1.5) 33 | ax2.set_ylabel('Volume', fontsize=14) 34 | ax2.legend(loc='upper left', fontsize=12) 35 | ax2.grid(True) 36 | 37 | plt.tight_layout() 38 | plt.show() 39 | 40 | 41 | # 1. Load Model and Tokenizer 42 | tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base") 43 | model = Kronos.from_pretrained("NeoQuasar/Kronos-small") 44 | 45 | # 2. Instantiate Predictor 46 | predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512) 47 | 48 | # 3. Prepare Data 49 | df = pd.read_csv("./data/XSHG_5min_600977.csv") 50 | df['timestamps'] = pd.to_datetime(df['timestamps']) 51 | 52 | lookback = 400 53 | pred_len = 120 54 | 55 | x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close', 'volume', 'amount']] 56 | x_timestamp = df.loc[:lookback-1, 'timestamps'] 57 | y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps'] 58 | 59 | # 4. Make Prediction 60 | pred_df = predictor.predict( 61 | df=x_df, 62 | x_timestamp=x_timestamp, 63 | y_timestamp=y_timestamp, 64 | pred_len=pred_len, 65 | T=1.0, 66 | top_p=0.9, 67 | sample_count=1, 68 | verbose=True 69 | ) 70 | 71 | # 5. 
Visualize Results 72 | print("Forecasted Data Head:") 73 | print(pred_df.head()) 74 | 75 | # Combine historical and forecasted data for plotting 76 | kline_df = df.loc[:lookback+pred_len-1] 77 | 78 | # visualize 79 | plot_prediction(kline_df, pred_df) 80 | 81 | -------------------------------------------------------------------------------- /webui/run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Kronos Web UI startup script 4 | """ 5 | 6 | import os 7 | import sys 8 | import subprocess 9 | import webbrowser 10 | import time 11 | 12 | def check_dependencies(): 13 | """Check if dependencies are installed""" 14 | try: 15 | import flask 16 | import flask_cors 17 | import pandas 18 | import numpy 19 | import plotly 20 | print("✅ All dependencies installed") 21 | return True 22 | except ImportError as e: 23 | print(f"❌ Missing dependency: {e}") 24 | print("Please run: pip install -r requirements.txt") 25 | return False 26 | 27 | def install_dependencies(): 28 | """Install dependencies""" 29 | print("Installing dependencies...") 30 | try: 31 | subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"]) 32 | print("✅ Dependencies installation completed") 33 | return True 34 | except subprocess.CalledProcessError: 35 | print("❌ Dependencies installation failed") 36 | return False 37 | 38 | def main(): 39 | """Main function""" 40 | print("🚀 Starting Kronos Web UI...") 41 | print("=" * 50) 42 | 43 | # Check dependencies 44 | if not check_dependencies(): 45 | print("\nAuto-install dependencies? (y/n): ", end="") 46 | if input().lower() == 'y': 47 | if not install_dependencies(): 48 | return 49 | else: 50 | print("Please manually install dependencies and retry") 51 | return 52 | 53 | # Check model availability 54 | try: 55 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 56 | from model import Kronos, KronosTokenizer, KronosPredictor 57 | print("✅ Kronos model library available") 58 | model_available = True 59 | except ImportError: 60 | print("⚠️ Kronos model library not available, will use simulated prediction") 61 | model_available = False 62 | 63 | # Start Flask application 64 | print("\n🌐 Starting Web server...") 65 | 66 | # Set environment variables 67 | os.environ['FLASK_APP'] = 'app.py' 68 | os.environ['FLASK_ENV'] = 'development' 69 | 70 | # Start server 71 | try: 72 | from app import app 73 | print("✅ Web server started successfully!") 74 | print(f"🌐 Access URL: http://localhost:7070") 75 | print("💡 Tip: Press Ctrl+C to stop server") 76 | 77 | # Auto-open browser 78 | time.sleep(2) 79 | webbrowser.open('http://localhost:7070') 80 | 81 | # Start Flask application 82 | app.run(debug=True, host='0.0.0.0', port=7070) 83 | 84 | except Exception as e: 85 | print(f"❌ Startup failed: {e}") 86 | print("Please check if port 7070 is occupied") 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /finetune_csv/README_CN.md: -------------------------------------------------------------------------------- 1 | # Kronos微调-支持自定义CSV数据集 2 | 3 | 这是一个在自定义的CSV格式数据上微调Kronos模型的完整流程。包含顺序训练(先训练tokenizer再训练predictor)和单独模块训练,同时支持分布式训练。 4 | 5 | 6 | ## 1. 
准备数据 7 | 8 | ### 数据格式 9 | 10 | CSV文件必须包含以下列: 11 | - `timestamps`: 每个数据点的时间戳 12 | - `open`: 开盘价 13 | - `high`: 最高价 14 | - `low`: 最低价 15 | - `close`: 收盘价 16 | - `volume`: 交易量 17 | - `amount`: 交易金额 18 | 19 | (volume和amount可以全0如果没有这部分的数据) 20 | 21 | ### 示例数据格式 22 | 23 | | timestamps | open | close | high | low | volume | amount | 24 | |------------|------|-------|------|-----|--------|--------| 25 | | 2019/11/26 9:35 | 182.45215 | 184.45215 | 184.95215 | 182.45215 | 15136000 | 0 | 26 | | 2019/11/26 9:40 | 184.35215 | 183.85215 | 184.55215 | 183.45215 | 4433300 | 0 | 27 | | 2019/11/26 9:45 | 183.85215 | 183.35215 | 183.95215 | 182.95215 | 3070900 | 0 | 28 | 29 | > **标准数据样例**: `data/HK_ali_09988_kline_5min_all.csv` 30 | 31 | ## 2. 准备config文件 32 | 33 | data_path及预训练模型路径需要修改,训练参数可以自己调节 34 | 35 | ```yaml 36 | # 数据配置 37 | data: 38 | data_path: "/path/to/your/data.csv" 39 | lookback_window: 512 # 要使用的历史数据点 40 | predict_window: 48 # 要预测的未来点数 41 | max_context: 512 # 最大上下文长度 42 | 43 | ... 44 | 45 | ``` 46 | 这里还有其他一些设置, `configs/config_ali09988_candle-5min.yaml` 有更详细的注释。 47 | 48 | ## 3. 训练 49 | 50 | ### 方法1: 直接顺序训练 51 | 52 | `train_sequential.py` 脚本自动处理完整的训练流程: 53 | 54 | ```bash 55 | # 完整训练(tokenizer + predictor) 56 | python train_sequential.py --config configs/config_ali09988_candle-5min.yaml 57 | 58 | # 跳过已存在的模型 59 | python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-existing 60 | 61 | # 只训练tokenizer 62 | python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-basemodel 63 | 64 | # 只训练predictor 65 | python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-tokenizer 66 | ``` 67 | 68 | ### 方法2: 单独组件训练 69 | 70 | 可以单独训练每个组件: 71 | 72 | ```bash 73 | # 步骤1: 训练tokenizer 74 | python finetune_tokenizer.py --config configs/config_ali09988_candle-5min.yaml 75 | 76 | # 步骤2: 训练predictor(需要微调后的tokenizer) 77 | python finetune_base_model.py --config configs/config_ali09988_candle-5min.yaml 78 | ``` 79 | 80 | ### DDP训练 81 | 82 | 如果有多卡,可以开启ddp加速训练: 83 | 84 | ```bash 85 | # 设置通信后端(NVIDIA GPU用nccl,CPU/混合用gloo) 86 | DIST_BACKEND=nccl \ 87 | torchrun --standalone --nproc_per_node=8 train_sequential.py --config configs/config_ali09988_candle-5min.yaml 88 | ``` 89 | 90 | ## 4. 训练结果 91 | 92 | 训练过程生成以下输出: 93 | 94 | ### 模型检查点 95 | - **Tokenizer**: 保存到 `{base_save_path}/{exp_name}/tokenizer/best_model/` 96 | - **Predictor**: 保存到 `{base_save_path}/{exp_name}/basemodel/best_model/` 97 | 98 | ### 训练日志 99 | - **控制台输出**: 实时训练进度和指标 100 | - **日志文件**: 详细日志保存到 `{base_save_path}/logs/` 101 | - **验证跟踪**: 基于验证损失保存最佳模型 102 | 103 | ## 5. 
预测可视化 104 | 105 | 以下图像显示了kronos在阿里巴巴股票数据上微调后的示例训练结果: 106 | 107 | ![训练结果 1](examples/HK_ali_09988_kline_5min_all_historical_20250919_073929.png) 108 | 109 | ![训练结果 2](examples/HK_ali_09988_kline_5min_all_historical_20250919_073944.png) 110 | 111 | ![训练结果 3](examples/HK_ali_09988_kline_5min_all_historical_20250919_074012.png) 112 | 113 | ![训练结果 4](examples/HK_ali_09988_kline_5min_all_historical_20250919_074042.png) 114 | 115 | ![训练结果 5](examples/HK_ali_09988_kline_5min_all_historical_20250919_074251.png) 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /finetune/utils/training_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import datetime 4 | import numpy as np 5 | import torch 6 | import torch.distributed as dist 7 | 8 | 9 | def setup_ddp(): 10 | """ 11 | Initializes the distributed data parallel environment. 12 | 13 | This function relies on environment variables set by `torchrun` or a similar 14 | launcher. It initializes the process group and sets the CUDA device for the 15 | current process. 16 | 17 | Returns: 18 | tuple: A tuple containing (rank, world_size, local_rank). 19 | """ 20 | if not dist.is_available(): 21 | raise RuntimeError("torch.distributed is not available.") 22 | 23 | dist.init_process_group(backend="nccl") 24 | rank = int(os.environ["RANK"]) 25 | world_size = int(os.environ["WORLD_SIZE"]) 26 | local_rank = int(os.environ["LOCAL_RANK"]) 27 | torch.cuda.set_device(local_rank) 28 | print( 29 | f"[DDP Setup] Global Rank: {rank}/{world_size}, " 30 | f"Local Rank (GPU): {local_rank} on device {torch.cuda.current_device()}" 31 | ) 32 | return rank, world_size, local_rank 33 | 34 | 35 | def cleanup_ddp(): 36 | """Cleans up the distributed process group.""" 37 | if dist.is_initialized(): 38 | dist.destroy_process_group() 39 | 40 | 41 | def set_seed(seed: int, rank: int = 0): 42 | """ 43 | Sets the random seed for reproducibility across all relevant libraries. 44 | 45 | Args: 46 | seed (int): The base seed value. 47 | rank (int): The process rank, used to ensure different processes have 48 | different seeds, which can be important for data loading. 49 | """ 50 | actual_seed = seed + rank 51 | random.seed(actual_seed) 52 | np.random.seed(actual_seed) 53 | torch.manual_seed(actual_seed) 54 | if torch.cuda.is_available(): 55 | torch.cuda.manual_seed_all(actual_seed) 56 | # The two lines below can impact performance, so they are often 57 | # reserved for final experiments where reproducibility is critical. 58 | torch.backends.cudnn.deterministic = True 59 | torch.backends.cudnn.benchmark = False 60 | 61 | 62 | def get_model_size(model: torch.nn.Module) -> str: 63 | """ 64 | Calculates the number of trainable parameters in a PyTorch model and returns 65 | it as a human-readable string. 66 | 67 | Args: 68 | model (torch.nn.Module): The PyTorch model. 69 | 70 | Returns: 71 | str: A string representing the model size (e.g., "175.0B", "7.1M", "50.5K"). 
72 | """ 73 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 74 | 75 | if total_params >= 1e9: 76 | return f"{total_params / 1e9:.1f}B" # Billions 77 | elif total_params >= 1e6: 78 | return f"{total_params / 1e6:.1f}M" # Millions 79 | else: 80 | return f"{total_params / 1e3:.1f}K" # Thousands 81 | 82 | 83 | def reduce_tensor(tensor: torch.Tensor, world_size: int, op=dist.ReduceOp.SUM) -> torch.Tensor: 84 | """ 85 | Reduces a tensor's value across all processes in a distributed setup. 86 | 87 | Args: 88 | tensor (torch.Tensor): The tensor to be reduced. 89 | world_size (int): The total number of processes. 90 | op (dist.ReduceOp, optional): The reduction operation (SUM, AVG, etc.). 91 | Defaults to dist.ReduceOp.SUM. 92 | 93 | Returns: 94 | torch.Tensor: The reduced tensor, which will be identical on all processes. 95 | """ 96 | rt = tensor.clone() 97 | dist.all_reduce(rt, op=op) 98 | # Note: `dist.ReduceOp.AVG` is available in newer torch versions. 99 | # For compatibility, manual division is sometimes used after a SUM. 100 | if op == dist.ReduceOp.AVG: 101 | rt /= world_size 102 | return rt 103 | 104 | 105 | def format_time(seconds: float) -> str: 106 | """ 107 | Formats a duration in seconds into a human-readable H:M:S string. 108 | 109 | Args: 110 | seconds (float): The total seconds. 111 | 112 | Returns: 113 | str: The formatted time string (e.g., "0:15:32"). 114 | """ 115 | return str(datetime.timedelta(seconds=int(seconds))) 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /finetune_csv/README.md: -------------------------------------------------------------------------------- 1 | # Kronos Fine-tuning on Custom CSV Datasets 2 | 3 | This module provides a comprehensive pipeline for fine-tuning Kronos models on your own CSV-formatted financial data. It supports both sequential training (tokenizer followed by predictor) and individual component training, with full distributed training capabilities. 4 | 5 | 6 | ## 1. Data Preparation 7 | 8 | ### Required Data Format 9 | 10 | Your CSV file must contain the following columns: 11 | - `timestamps`: DateTime stamps for each data point 12 | - `open`: Opening price 13 | - `high`: Highest price 14 | - `low`: Lowest price 15 | - `close`: Closing price 16 | - `volume`: Trading volume 17 | - `amount`: Trading amount 18 | 19 | (volume and amount can be 0 if not available) 20 | 21 | ### Sample Data Format 22 | 23 | | timestamps | open | close | high | low | volume | amount | 24 | |------------|------|-------|------|-----|--------|--------| 25 | | 2019/11/26 9:35 | 182.45215 | 184.45215 | 184.95215 | 182.45215 | 15136000 | 0 | 26 | | 2019/11/26 9:40 | 184.35215 | 183.85215 | 184.55215 | 183.45215 | 4433300 | 0 | 27 | | 2019/11/26 9:45 | 183.85215 | 183.35215 | 183.95215 | 182.95215 | 3070900 | 0 | 28 | 29 | > **Reference**: Check `data/HK_ali_09988_kline_5min_all.csv` for a complete example of the proper data format. 30 | 31 | 32 | ## 2. Config Preparation 33 | 34 | 35 | Please edit the correct data path & pretrained model path and set your training parameters. 36 | 37 | ```yaml 38 | # Data configuration 39 | data: 40 | data_path: "/path/to/your/data.csv" 41 | lookback_window: 512 # Historical data points to use 42 | predict_window: 48 # Future points to predict 43 | max_context: 512 # Maximum context length 44 | 45 | ... 46 | 47 | ``` 48 | There are some other settings here, please see `configs/config_ali09988_candle-5min.yaml` for more comments. 49 | 50 | ## 3. 
Training 51 | 52 | ### Method 1: Sequential Training (Recommended) 53 | 54 | The `train_sequential.py` script handles the complete training pipeline automatically: 55 | 56 | ```bash 57 | # Complete training (tokenizer + predictor) 58 | python train_sequential.py --config configs/config_ali09988_candle-5min.yaml 59 | 60 | # Skip existing models 61 | python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-existing 62 | 63 | # Only train tokenizer 64 | python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-basemodel 65 | 66 | # Only train predictor 67 | python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-tokenizer 68 | ``` 69 | 70 | ### Method 2: Individual Component Training 71 | 72 | Train each component separately for more control: 73 | 74 | ```bash 75 | # Step 1: Train tokenizer 76 | python finetune_tokenizer.py --config configs/config_ali09988_candle-5min.yaml 77 | 78 | # Step 2: Train predictor (requires fine-tuned tokenizer) 79 | python finetune_base_model.py --config configs/config_ali09988_candle-5min.yaml 80 | ``` 81 | 82 | ### DDP Training 83 | 84 | For faster training on multiple GPUs: 85 | 86 | ```bash 87 | # Set communication backend (nccl for NVIDIA GPUs, gloo for CPU/mixed) 88 | DIST_BACKEND=nccl \ 89 | torchrun --standalone --nproc_per_node=8 train_sequential.py --config configs/config_ali09988_candle-5min.yaml 90 | ``` 91 | 92 | ## 4. Training Results 93 | 94 | The training process generates several outputs: 95 | 96 | ### Model Checkpoints 97 | - **Tokenizer**: Saved to `{base_save_path}/{exp_name}/tokenizer/best_model/` 98 | - **Predictor**: Saved to `{base_save_path}/{exp_name}/basemodel/best_model/` 99 | 100 | ### Training Logs 101 | - **Console output**: Real-time training progress and metrics 102 | - **Log files**: Detailed logs saved to `{base_save_path}/logs/` 103 | - **Validation tracking**: Best models are saved based on validation loss 104 | 105 | ## 5. Prediction Vis 106 | 107 | The following images show example training results on alibaba (HK stock) data: 108 | 109 | ![Training Result 1](examples/HK_ali_09988_kline_5min_all_historical_20250919_073929.png) 110 | 111 | ![Training Result 2](examples/HK_ali_09988_kline_5min_all_historical_20250919_073944.png) 112 | 113 | ![Training Result 3](examples/HK_ali_09988_kline_5min_all_historical_20250919_074012.png) 114 | 115 | ![Training Result 4](examples/HK_ali_09988_kline_5min_all_historical_20250919_074042.png) 116 | 117 | ![Training Result 5](examples/HK_ali_09988_kline_5min_all_historical_20250919_074251.png) 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /webui/README.md: -------------------------------------------------------------------------------- 1 | # Kronos Web UI 2 | 3 | Web user interface for Kronos financial prediction model, providing intuitive graphical operation interface. 
4 | 5 | ## ✨ Features 6 | 7 | - **Multi-format data support**: Supports CSV, Feather and other financial data formats 8 | - **Smart time window**: Fixed 400+120 data point time window slider selection 9 | - **Real model prediction**: Integrated real Kronos model, supports multiple model sizes 10 | - **Prediction quality control**: Adjustable temperature, nucleus sampling, sample count and other parameters 11 | - **Multi-device support**: Supports CPU, CUDA, MPS and other computing devices 12 | - **Comparison analysis**: Detailed comparison between prediction results and actual data 13 | - **K-line chart display**: Professional financial K-line chart display 14 | 15 | ## 🚀 Quick Start 16 | 17 | ### Method 1: Start with Python script 18 | ```bash 19 | cd webui 20 | python run.py 21 | ``` 22 | 23 | ### Method 2: Start with Shell script 24 | ```bash 25 | cd webui 26 | chmod +x start.sh 27 | ./start.sh 28 | ``` 29 | 30 | ### Method 3: Start Flask application directly 31 | ```bash 32 | cd webui 33 | python app.py 34 | ``` 35 | 36 | After successful startup, visit http://localhost:7070 37 | 38 | ## 📋 Usage Steps 39 | 40 | 1. **Load data**: Select financial data file from data directory 41 | 2. **Load model**: Select Kronos model and computing device 42 | 3. **Set parameters**: Adjust prediction quality parameters 43 | 4. **Select time window**: Use slider to select 400+120 data point time range 44 | 5. **Start prediction**: Click prediction button to generate results 45 | 6. **View results**: View prediction results in charts and tables 46 | 47 | ## 🔧 Prediction Quality Parameters 48 | 49 | ### Temperature (T) 50 | - **Range**: 0.1 - 2.0 51 | - **Effect**: Controls prediction randomness 52 | - **Recommendation**: 1.2-1.5 for better prediction quality 53 | 54 | ### Nucleus Sampling (top_p) 55 | - **Range**: 0.1 - 1.0 56 | - **Effect**: Controls prediction diversity 57 | - **Recommendation**: 0.95-1.0 to consider more possibilities 58 | 59 | ### Sample Count 60 | - **Range**: 1 - 5 61 | - **Effect**: Generate multiple prediction samples 62 | - **Recommendation**: 2-3 samples to improve quality 63 | 64 | ## 📊 Supported Data Formats 65 | 66 | ### Required Columns 67 | - `open`: Opening price 68 | - `high`: Highest price 69 | - `low`: Lowest price 70 | - `close`: Closing price 71 | 72 | ### Optional Columns 73 | - `volume`: Trading volume 74 | - `amount`: Trading amount (not used for prediction) 75 | - `timestamps`/`timestamp`/`date`: Timestamp 76 | 77 | ## 🤖 Model Support 78 | 79 | - **Kronos-mini**: 4.1M parameters, lightweight fast prediction 80 | - **Kronos-small**: 24.7M parameters, balanced performance and speed 81 | - **Kronos-base**: 102.3M parameters, high quality prediction 82 | 83 | ## 🖥️ GPU Acceleration Support 84 | 85 | - **CPU**: General computing, best compatibility 86 | - **CUDA**: NVIDIA GPU acceleration, best performance 87 | - **MPS**: Apple Silicon GPU acceleration, recommended for Mac users 88 | 89 | ## ⚠️ Notes 90 | 91 | - `amount` column is not used for prediction, only for display 92 | - Time window is fixed at 400+120=520 data points 93 | - Ensure data file contains sufficient historical data 94 | - First model loading may require download, please be patient 95 | 96 | ## 🔍 Comparison Analysis 97 | 98 | The system automatically provides comparison analysis between prediction results and actual data, including: 99 | - Price difference statistics 100 | - Error analysis 101 | - Prediction quality assessment 102 | 103 | ## 🛠️ Technical Architecture 104 | 105 | - **Backend**: Flask 
+ Python 106 | - **Frontend**: HTML + CSS + JavaScript 107 | - **Charts**: Plotly.js 108 | - **Data processing**: Pandas + NumPy 109 | - **Model**: Hugging Face Transformers 110 | 111 | ## 📝 Troubleshooting 112 | 113 | ### Common Issues 114 | 1. **Port occupied**: Modify port number in app.py 115 | 2. **Missing dependencies**: Run `pip install -r requirements.txt` 116 | 3. **Model loading failed**: Check network connection and model ID 117 | 4. **Data format error**: Ensure data column names and format are correct 118 | 119 | ### Log Viewing 120 | Detailed runtime information will be displayed in the console at startup, including model status and error messages. 121 | 122 | ## 📄 License 123 | 124 | This project follows the license terms of the original Kronos project. 125 | 126 | ## 🤝 Contributing 127 | 128 | Welcome to submit Issues and Pull Requests to improve this Web UI! 129 | 130 | ## 📞 Support 131 | 132 | If you have questions, please check: 133 | 1. Project documentation 134 | 2. GitHub Issues 135 | 3. Console error messages 136 | -------------------------------------------------------------------------------- /finetune/qlib_data_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import pandas as pd 5 | import qlib 6 | from qlib.config import REG_CN 7 | from qlib.data import D 8 | from qlib.data.dataset.loader import QlibDataLoader 9 | from tqdm import trange 10 | 11 | from config import Config 12 | 13 | 14 | class QlibDataPreprocessor: 15 | """ 16 | A class to handle the loading, processing, and splitting of Qlib financial data. 17 | """ 18 | 19 | def __init__(self): 20 | """Initializes the preprocessor with configuration and data fields.""" 21 | self.config = Config() 22 | self.data_fields = ['open', 'close', 'high', 'low', 'volume', 'vwap'] 23 | self.data = {} # A dictionary to store processed data for each symbol. 24 | 25 | def initialize_qlib(self): 26 | """Initializes the Qlib environment.""" 27 | print("Initializing Qlib...") 28 | qlib.init(provider_uri=self.config.qlib_data_path, region=REG_CN) 29 | 30 | def load_qlib_data(self): 31 | """ 32 | Loads raw data from Qlib, processes it symbol by symbol, and stores 33 | it in the `self.data` attribute. 34 | """ 35 | print("Loading and processing data from Qlib...") 36 | data_fields_qlib = ['$' + f for f in self.data_fields] 37 | cal: np.ndarray = D.calendar() 38 | 39 | # Determine the actual start and end times to load, including buffer for lookback and predict windows. 40 | start_index = cal.searchsorted(pd.Timestamp(self.config.dataset_begin_time)) 41 | end_index = cal.searchsorted(pd.Timestamp(self.config.dataset_end_time)) 42 | 43 | # Check if start_index lookbackw_window will cause negative index 44 | adjusted_start_index = max(start_index - self.config.lookback_window, 0) 45 | real_start_time = cal[adjusted_start_index] 46 | 47 | # Check if end_index exceeds the range of the array 48 | if end_index >= len(cal): 49 | end_index = len(cal) - 1 50 | elif cal[end_index] != pd.Timestamp(self.config.dataset_end_time): 51 | end_index -= 1 52 | 53 | # Check if end_index+predictw_window will exceed the range of the array 54 | adjusted_end_index = min(end_index + self.config.predict_window, len(cal) - 1) 55 | real_end_time = cal[adjusted_end_index] 56 | 57 | # Load data using Qlib's data loader. 
58 | data_df = QlibDataLoader(config=data_fields_qlib).load( 59 | self.config.instrument, real_start_time, real_end_time 60 | ) 61 | data_df = data_df.stack().unstack(level=1) # Reshape for easier access. 62 | 63 | symbol_list = list(data_df.columns) 64 | for i in trange(len(symbol_list), desc="Processing Symbols"): 65 | symbol = symbol_list[i] 66 | symbol_df = data_df[symbol] 67 | 68 | # Pivot the table to have features as columns and datetime as index. 69 | symbol_df = symbol_df.reset_index().rename(columns={'level_1': 'field'}) 70 | symbol_df = pd.pivot(symbol_df, index='datetime', columns='field', values=symbol) 71 | symbol_df = symbol_df.rename(columns={f'${field}': field for field in self.data_fields}) 72 | 73 | # Calculate amount and select final features. 74 | symbol_df['vol'] = symbol_df['volume'] 75 | symbol_df['amt'] = (symbol_df['open'] + symbol_df['high'] + symbol_df['low'] + symbol_df['close']) / 4 * symbol_df['vol'] 76 | symbol_df = symbol_df[self.config.feature_list] 77 | 78 | # Filter out symbols with insufficient data. 79 | symbol_df = symbol_df.dropna() 80 | if len(symbol_df) < self.config.lookback_window + self.config.predict_window + 1: 81 | continue 82 | 83 | self.data[symbol] = symbol_df 84 | 85 | def prepare_dataset(self): 86 | """ 87 | Splits the loaded data into train, validation, and test sets and saves them to disk. 88 | """ 89 | print("Splitting data into train, validation, and test sets...") 90 | train_data, val_data, test_data = {}, {}, {} 91 | 92 | symbol_list = list(self.data.keys()) 93 | for i in trange(len(symbol_list), desc="Preparing Datasets"): 94 | symbol = symbol_list[i] 95 | symbol_df = self.data[symbol] 96 | 97 | # Define time ranges from config. 98 | train_start, train_end = self.config.train_time_range 99 | val_start, val_end = self.config.val_time_range 100 | test_start, test_end = self.config.test_time_range 101 | 102 | # Create boolean masks for each dataset split. 103 | train_mask = (symbol_df.index >= train_start) & (symbol_df.index <= train_end) 104 | val_mask = (symbol_df.index >= val_start) & (symbol_df.index <= val_end) 105 | test_mask = (symbol_df.index >= test_start) & (symbol_df.index <= test_end) 106 | 107 | # Apply masks to create the final datasets. 108 | train_data[symbol] = symbol_df[train_mask] 109 | val_data[symbol] = symbol_df[val_mask] 110 | test_data[symbol] = symbol_df[test_mask] 111 | 112 | # Save the datasets using pickle. 113 | os.makedirs(self.config.dataset_path, exist_ok=True) 114 | with open(f"{self.config.dataset_path}/train_data.pkl", 'wb') as f: 115 | pickle.dump(train_data, f) 116 | with open(f"{self.config.dataset_path}/val_data.pkl", 'wb') as f: 117 | pickle.dump(val_data, f) 118 | with open(f"{self.config.dataset_path}/test_data.pkl", 'wb') as f: 119 | pickle.dump(test_data, f) 120 | 121 | print("Datasets prepared and saved successfully.") 122 | 123 | 124 | if __name__ == '__main__': 125 | # This block allows the script to be run directly to perform data preprocessing. 
126 | preprocessor = QlibDataPreprocessor() 127 | preprocessor.initialize_qlib() 128 | preprocessor.load_qlib_data() 129 | preprocessor.prepare_dataset() 130 | 131 | -------------------------------------------------------------------------------- /finetune/dataset.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import random 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | from config import Config 7 | 8 | 9 | class QlibDataset(Dataset): 10 | """ 11 | A PyTorch Dataset for handling Qlib financial time series data. 12 | 13 | This dataset pre-computes all possible start indices for sliding windows 14 | and then randomly samples from them during training/validation. 15 | 16 | Args: 17 | data_type (str): The type of dataset to load, either 'train' or 'val'. 18 | 19 | Raises: 20 | ValueError: If `data_type` is not 'train' or 'val'. 21 | """ 22 | 23 | def __init__(self, data_type: str = 'train'): 24 | self.config = Config() 25 | if data_type not in ['train', 'val']: 26 | raise ValueError("data_type must be 'train' or 'val'") 27 | self.data_type = data_type 28 | 29 | # Use a dedicated random number generator for sampling to avoid 30 | # interfering with other random processes (e.g., in model initialization). 31 | self.py_rng = random.Random(self.config.seed) 32 | 33 | # Set paths and number of samples based on the data type. 34 | if data_type == 'train': 35 | self.data_path = f"{self.config.dataset_path}/train_data.pkl" 36 | self.n_samples = self.config.n_train_iter 37 | else: 38 | self.data_path = f"{self.config.dataset_path}/val_data.pkl" 39 | self.n_samples = self.config.n_val_iter 40 | 41 | with open(self.data_path, 'rb') as f: 42 | self.data = pickle.load(f) 43 | 44 | self.window = self.config.lookback_window + self.config.predict_window + 1 45 | 46 | self.symbols = list(self.data.keys()) 47 | self.feature_list = self.config.feature_list 48 | self.time_feature_list = self.config.time_feature_list 49 | 50 | # Pre-compute all possible (symbol, start_index) pairs. 51 | self.indices = [] 52 | print(f"[{data_type.upper()}] Pre-computing sample indices...") 53 | for symbol in self.symbols: 54 | df = self.data[symbol].reset_index() 55 | series_len = len(df) 56 | num_samples = series_len - self.window + 1 57 | 58 | if num_samples > 0: 59 | # Generate time features and store them directly in the dataframe. 60 | df['minute'] = df['datetime'].dt.minute 61 | df['hour'] = df['datetime'].dt.hour 62 | df['weekday'] = df['datetime'].dt.weekday 63 | df['day'] = df['datetime'].dt.day 64 | df['month'] = df['datetime'].dt.month 65 | # Keep only necessary columns to save memory. 66 | self.data[symbol] = df[self.feature_list + self.time_feature_list] 67 | 68 | # Add all valid starting indices for this symbol to the global list. 69 | for i in range(num_samples): 70 | self.indices.append((symbol, i)) 71 | 72 | # The effective dataset size is the minimum of the configured iterations 73 | # and the total number of available samples. 74 | self.n_samples = min(self.n_samples, len(self.indices)) 75 | print(f"[{data_type.upper()}] Found {len(self.indices)} possible samples. Using {self.n_samples} per epoch.") 76 | 77 | def set_epoch_seed(self, epoch: int): 78 | """ 79 | Sets a new seed for the random sampler for each epoch. This is crucial 80 | for reproducibility in distributed training. 81 | 82 | Args: 83 | epoch (int): The current epoch number. 
84 | """ 85 | epoch_seed = self.config.seed + epoch 86 | self.py_rng.seed(epoch_seed) 87 | 88 | def __len__(self) -> int: 89 | """Returns the number of samples per epoch.""" 90 | return self.n_samples 91 | 92 | def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: 93 | """ 94 | Retrieves a random sample from the dataset. 95 | 96 | Note: The `idx` argument is ignored. Instead, a random index is drawn 97 | from the pre-computed `self.indices` list using `self.py_rng`. This 98 | ensures random sampling over the entire dataset for each call. 99 | 100 | Args: 101 | idx (int): Ignored. 102 | 103 | Returns: 104 | tuple[torch.Tensor, torch.Tensor]: A tuple containing: 105 | - x_tensor (torch.Tensor): The normalized feature tensor. 106 | - x_stamp_tensor (torch.Tensor): The time feature tensor. 107 | """ 108 | # Select a random sample from the entire pool of indices. 109 | random_idx = self.py_rng.randint(0, len(self.indices) - 1) 110 | symbol, start_idx = self.indices[random_idx] 111 | 112 | # Extract the sliding window from the dataframe. 113 | df = self.data[symbol] 114 | end_idx = start_idx + self.window 115 | win_df = df.iloc[start_idx:end_idx] 116 | 117 | # Separate main features and time features. 118 | x = win_df[self.feature_list].values.astype(np.float32) 119 | x_stamp = win_df[self.time_feature_list].values.astype(np.float32) 120 | 121 | # Perform instance-level normalization. 122 | x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0) 123 | x = (x - x_mean) / (x_std + 1e-5) 124 | x = np.clip(x, -self.config.clip, self.config.clip) 125 | 126 | # Convert to PyTorch tensors. 127 | x_tensor = torch.from_numpy(x) 128 | x_stamp_tensor = torch.from_numpy(x_stamp) 129 | 130 | return x_tensor, x_stamp_tensor 131 | 132 | 133 | if __name__ == '__main__': 134 | # Example usage and verification. 135 | print("Creating training dataset instance...") 136 | train_dataset = QlibDataset(data_type='train') 137 | 138 | print(f"Dataset length: {len(train_dataset)}") 139 | 140 | if len(train_dataset) > 0: 141 | try_x, try_x_stamp = train_dataset[100] # Index 100 is ignored. 142 | print(f"Sample feature shape: {try_x.shape}") 143 | print(f"Sample time feature shape: {try_x_stamp.shape}") 144 | else: 145 | print("Dataset is empty.") 146 | -------------------------------------------------------------------------------- /finetune/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class Config: 4 | """ 5 | Configuration class for the entire project. 6 | """ 7 | 8 | def __init__(self): 9 | # ================================================================= 10 | # Data & Feature Parameters 11 | # ================================================================= 12 | # TODO: Update this path to your Qlib data directory. 13 | self.qlib_data_path = "~/.qlib/qlib_data/cn_data" 14 | self.instrument = 'csi300' 15 | 16 | # Overall time range for data loading from Qlib. 17 | self.dataset_begin_time = "2011-01-01" 18 | self.dataset_end_time = '2025-06-05' 19 | 20 | # Sliding window parameters for creating samples. 21 | self.lookback_window = 90 # Number of past time steps for input. 22 | self.predict_window = 10 # Number of future time steps for prediction. 23 | self.max_context = 512 # Maximum context length for the model. 24 | 25 | # Features to be used from the raw data. 26 | self.feature_list = ['open', 'high', 'low', 'close', 'vol', 'amt'] 27 | # Time-based features to be generated. 
28 | self.time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month'] 29 | 30 | # ================================================================= 31 | # Dataset Splitting & Paths 32 | # ================================================================= 33 | # Note: The validation/test set starts earlier than the training/validation set ends 34 | # to account for the `lookback_window`. 35 | self.train_time_range = ["2011-01-01", "2022-12-31"] 36 | self.val_time_range = ["2022-09-01", "2024-06-30"] 37 | self.test_time_range = ["2024-04-01", "2025-06-05"] 38 | self.backtest_time_range = ["2024-07-01", "2025-06-05"] 39 | 40 | # TODO: Directory to save the processed, pickled datasets. 41 | self.dataset_path = "./data/processed_datasets" 42 | 43 | # ================================================================= 44 | # Training Hyperparameters 45 | # ================================================================= 46 | self.clip = 5.0 # Clipping value for normalized data to prevent outliers. 47 | 48 | self.epochs = 30 49 | self.log_interval = 100 # Log training status every N batches. 50 | self.batch_size = 50 # Batch size per GPU. 51 | 52 | # Number of samples to draw for one "epoch" of training/validation. 53 | # This is useful for large datasets where a true epoch is too long. 54 | self.n_train_iter = 2000 * self.batch_size 55 | self.n_val_iter = 400 * self.batch_size 56 | 57 | # Learning rates for different model components. 58 | self.tokenizer_learning_rate = 2e-4 59 | self.predictor_learning_rate = 4e-5 60 | 61 | # Gradient accumulation to simulate a larger batch size. 62 | self.accumulation_steps = 1 63 | 64 | # AdamW optimizer parameters. 65 | self.adam_beta1 = 0.9 66 | self.adam_beta2 = 0.95 67 | self.adam_weight_decay = 0.1 68 | 69 | # Miscellaneous 70 | self.seed = 100 # Global random seed for reproducibility. 71 | 72 | # ================================================================= 73 | # Experiment Logging & Saving 74 | # ================================================================= 75 | self.use_comet = True # Set to False if you don't want to use Comet ML 76 | self.comet_config = { 77 | # It is highly recommended to load secrets from environment variables 78 | # for security purposes. Example: os.getenv("COMET_API_KEY") 79 | "api_key": "YOUR_COMET_API_KEY", 80 | "project_name": "Kronos-Finetune-Demo", 81 | "workspace": "your_comet_workspace" # TODO: Change to your Comet ML workspace name 82 | } 83 | self.comet_tag = 'finetune_demo' 84 | self.comet_name = 'finetune_demo' 85 | 86 | # Base directory for saving model checkpoints and results. 87 | # Using a general 'outputs' directory is a common practice. 88 | self.save_path = "./outputs/models" 89 | self.tokenizer_save_folder_name = 'finetune_tokenizer_demo' 90 | self.predictor_save_folder_name = 'finetune_predictor_demo' 91 | self.backtest_save_folder_name = 'finetune_backtest_demo' 92 | 93 | # Path for backtesting results. 94 | self.backtest_result_path = "./outputs/backtest_results" 95 | 96 | # ================================================================= 97 | # Model & Checkpoint Paths 98 | # ================================================================= 99 | # TODO: Update these paths to your pretrained model locations. 100 | # These can be local paths or Hugging Face Hub model identifiers. 101 | self.pretrained_tokenizer_path = "path/to/your/Kronos-Tokenizer-base" 102 | self.pretrained_predictor_path = "path/to/your/Kronos-small" 103 | 104 | # Paths to the fine-tuned models, derived from the save_path. 
105 | # These will be generated automatically during training. 106 | self.finetuned_tokenizer_path = f"{self.save_path}/{self.tokenizer_save_folder_name}/checkpoints/best_model" 107 | self.finetuned_predictor_path = f"{self.save_path}/{self.predictor_save_folder_name}/checkpoints/best_model" 108 | 109 | # ================================================================= 110 | # Backtesting Parameters 111 | # ================================================================= 112 | self.backtest_n_symbol_hold = 50 # Number of symbols to hold in the portfolio. 113 | self.backtest_n_symbol_drop = 5 # Number of symbols to drop from the pool. 114 | self.backtest_hold_thresh = 5 # Minimum holding period for a stock. 115 | self.inference_T = 0.6 116 | self.inference_top_p = 0.9 117 | self.inference_top_k = 0 118 | self.inference_sample_count = 5 119 | self.backtest_batch_size = 1000 120 | self.backtest_benchmark = self._set_benchmark(self.instrument) 121 | 122 | def _set_benchmark(self, instrument): 123 | dt_benchmark = { 124 | 'csi800': "SH000906", 125 | 'csi1000': "SH000852", 126 | 'csi300': "SH000300", 127 | } 128 | if instrument in dt_benchmark: 129 | return dt_benchmark[instrument] 130 | else: 131 | raise ValueError(f"Benchmark not defined for instrument: {instrument}") 132 | -------------------------------------------------------------------------------- /finetune/train_predictor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import time 5 | from time import gmtime, strftime 6 | import torch.distributed as dist 7 | import torch 8 | from torch.utils.data import DataLoader 9 | from torch.utils.data.distributed import DistributedSampler 10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | 12 | import comet_ml 13 | 14 | # Ensure project root is in path 15 | sys.path.append('../') 16 | from config import Config 17 | from dataset import QlibDataset 18 | from model.kronos import KronosTokenizer, Kronos 19 | # Import shared utilities 20 | from utils.training_utils import ( 21 | setup_ddp, 22 | cleanup_ddp, 23 | set_seed, 24 | get_model_size, 25 | format_time 26 | ) 27 | 28 | 29 | def create_dataloaders(config: dict, rank: int, world_size: int): 30 | """ 31 | Creates and returns distributed dataloaders for training and validation. 32 | 33 | Args: 34 | config (dict): A dictionary of configuration parameters. 35 | rank (int): The global rank of the current process. 36 | world_size (int): The total number of processes. 37 | 38 | Returns: 39 | tuple: (train_loader, val_loader, train_dataset, valid_dataset). 
40 | """ 41 | print(f"[Rank {rank}] Creating distributed dataloaders...") 42 | train_dataset = QlibDataset('train') 43 | valid_dataset = QlibDataset('val') 44 | print(f"[Rank {rank}] Train dataset size: {len(train_dataset)}, Validation dataset size: {len(valid_dataset)}") 45 | 46 | train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True) 47 | val_sampler = DistributedSampler(valid_dataset, num_replicas=world_size, rank=rank, shuffle=False) 48 | 49 | train_loader = DataLoader( 50 | train_dataset, batch_size=config['batch_size'], sampler=train_sampler, 51 | num_workers=config.get('num_workers', 2), pin_memory=True, drop_last=True 52 | ) 53 | val_loader = DataLoader( 54 | valid_dataset, batch_size=config['batch_size'], sampler=val_sampler, 55 | num_workers=config.get('num_workers', 2), pin_memory=True, drop_last=False 56 | ) 57 | return train_loader, val_loader, train_dataset, valid_dataset 58 | 59 | 60 | def train_model(model, tokenizer, device, config, save_dir, logger, rank, world_size): 61 | """ 62 | The main training and validation loop for the predictor. 63 | """ 64 | start_time = time.time() 65 | if rank == 0: 66 | effective_bs = config['batch_size'] * world_size 67 | print(f"Effective BATCHSIZE per GPU: {config['batch_size']}, Total: {effective_bs}") 68 | 69 | train_loader, val_loader, train_dataset, valid_dataset = create_dataloaders(config, rank, world_size) 70 | 71 | optimizer = torch.optim.AdamW( 72 | model.parameters(), 73 | lr=config['predictor_learning_rate'], 74 | betas=(config['adam_beta1'], config['adam_beta2']), 75 | weight_decay=config['adam_weight_decay'] 76 | ) 77 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 78 | optimizer, max_lr=config['predictor_learning_rate'], 79 | steps_per_epoch=len(train_loader), epochs=config['epochs'], 80 | pct_start=0.03, div_factor=10 81 | ) 82 | 83 | best_val_loss = float('inf') 84 | dt_result = {} 85 | batch_idx_global = 0 86 | 87 | for epoch_idx in range(config['epochs']): 88 | epoch_start_time = time.time() 89 | model.train() 90 | train_loader.sampler.set_epoch(epoch_idx) 91 | 92 | train_dataset.set_epoch_seed(epoch_idx * 10000 + rank) 93 | valid_dataset.set_epoch_seed(0) 94 | 95 | for i, (batch_x, batch_x_stamp) in enumerate(train_loader): 96 | batch_x = batch_x.squeeze(0).to(device, non_blocking=True) 97 | batch_x_stamp = batch_x_stamp.squeeze(0).to(device, non_blocking=True) 98 | 99 | # Tokenize input data on-the-fly 100 | with torch.no_grad(): 101 | token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True) 102 | 103 | # Prepare inputs and targets for the language model 104 | token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]] 105 | token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]] 106 | 107 | # Forward pass and loss calculation 108 | logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :]) 109 | loss, s1_loss, s2_loss = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1]) 110 | 111 | # Backward pass and optimization 112 | optimizer.zero_grad() 113 | loss.backward() 114 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0) 115 | optimizer.step() 116 | scheduler.step() 117 | 118 | # Logging (Master Process Only) 119 | if rank == 0 and (batch_idx_global + 1) % config['log_interval'] == 0: 120 | lr = optimizer.param_groups[0]['lr'] 121 | print( 122 | f"[Rank {rank}, Epoch {epoch_idx + 1}/{config['epochs']}, Step {i + 1}/{len(train_loader)}] " 123 | f"LR {lr:.6f}, Loss: {loss.item():.4f}" 124 | ) 125 | if rank == 0 and logger: 126 
| lr = optimizer.param_groups[0]['lr'] 127 | logger.log_metric('train_predictor_loss_batch', loss.item(), step=batch_idx_global) 128 | logger.log_metric('train_S1_loss_each_batch', s1_loss.item(), step=batch_idx_global) 129 | logger.log_metric('train_S2_loss_each_batch', s2_loss.item(), step=batch_idx_global) 130 | logger.log_metric('predictor_learning_rate', lr, step=batch_idx_global) 131 | 132 | batch_idx_global += 1 133 | 134 | # --- Validation Loop --- 135 | model.eval() 136 | tot_val_loss_sum_rank = 0.0 137 | val_batches_processed_rank = 0 138 | with torch.no_grad(): 139 | for batch_x, batch_x_stamp in val_loader: 140 | batch_x = batch_x.squeeze(0).to(device, non_blocking=True) 141 | batch_x_stamp = batch_x_stamp.squeeze(0).to(device, non_blocking=True) 142 | 143 | token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True) 144 | token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]] 145 | token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]] 146 | 147 | logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :]) 148 | val_loss, _, _ = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1]) 149 | 150 | tot_val_loss_sum_rank += val_loss.item() 151 | val_batches_processed_rank += 1 152 | 153 | # Reduce validation metrics 154 | val_loss_sum_tensor = torch.tensor(tot_val_loss_sum_rank, device=device) 155 | val_batches_tensor = torch.tensor(val_batches_processed_rank, device=device) 156 | dist.all_reduce(val_loss_sum_tensor, op=dist.ReduceOp.SUM) 157 | dist.all_reduce(val_batches_tensor, op=dist.ReduceOp.SUM) 158 | 159 | avg_val_loss = val_loss_sum_tensor.item() / val_batches_tensor.item() if val_batches_tensor.item() > 0 else 0 160 | 161 | # --- End of Epoch Summary & Checkpointing (Master Process Only) --- 162 | if rank == 0: 163 | print(f"\n--- Epoch {epoch_idx + 1}/{config['epochs']} Summary ---") 164 | print(f"Validation Loss: {avg_val_loss:.4f}") 165 | print(f"Time This Epoch: {format_time(time.time() - epoch_start_time)}") 166 | print(f"Total Time Elapsed: {format_time(time.time() - start_time)}\n") 167 | if logger: 168 | logger.log_metric('val_predictor_loss_epoch', avg_val_loss, epoch=epoch_idx) 169 | 170 | if avg_val_loss < best_val_loss: 171 | best_val_loss = avg_val_loss 172 | save_path = f"{save_dir}/checkpoints/best_model" 173 | model.module.save_pretrained(save_path) 174 | print(f"Best model saved to {save_path} (Val Loss: {best_val_loss:.4f})") 175 | 176 | dist.barrier() 177 | 178 | dt_result['best_val_loss'] = best_val_loss 179 | return dt_result 180 | 181 | 182 | def main(config: dict): 183 | """Main function to orchestrate the DDP training process.""" 184 | rank, world_size, local_rank = setup_ddp() 185 | device = torch.device(f"cuda:{local_rank}") 186 | set_seed(config['seed'], rank) 187 | 188 | save_dir = os.path.join(config['save_path'], config['predictor_save_folder_name']) 189 | 190 | # Logger and summary setup (master process only) 191 | comet_logger, master_summary = None, {} 192 | if rank == 0: 193 | os.makedirs(os.path.join(save_dir, 'checkpoints'), exist_ok=True) 194 | master_summary = { 195 | 'start_time': strftime("%Y-%m-%dT%H-%M-%S", gmtime()), 196 | 'save_directory': save_dir, 197 | 'world_size': world_size, 198 | } 199 | if config['use_comet']: 200 | comet_logger = comet_ml.Experiment( 201 | api_key=config['comet_config']['api_key'], 202 | project_name=config['comet_config']['project_name'], 203 | workspace=config['comet_config']['workspace'], 204 | ) 205 | comet_logger.add_tag(config['comet_tag']) 206 | 
comet_logger.set_name(config['comet_name']) 207 | comet_logger.log_parameters(config) 208 | print("Comet Logger Initialized.") 209 | 210 | dist.barrier() 211 | 212 | # Model Initialization 213 | tokenizer = KronosTokenizer.from_pretrained(config['finetuned_tokenizer_path']) 214 | tokenizer.eval().to(device) 215 | 216 | model = Kronos.from_pretrained(config['pretrained_predictor_path']) 217 | model.to(device) 218 | model = DDP(model, device_ids=[local_rank], find_unused_parameters=False) 219 | 220 | if rank == 0: 221 | print(f"Predictor Model Size: {get_model_size(model.module)}") 222 | 223 | # Start Training 224 | dt_result = train_model( 225 | model, tokenizer, device, config, save_dir, comet_logger, rank, world_size 226 | ) 227 | 228 | if rank == 0: 229 | master_summary['final_result'] = dt_result 230 | with open(os.path.join(save_dir, 'summary.json'), 'w') as f: 231 | json.dump(master_summary, f, indent=4) 232 | print('Training finished. Summary file saved.') 233 | if comet_logger: comet_logger.end() 234 | 235 | cleanup_ddp() 236 | 237 | 238 | if __name__ == '__main__': 239 | # Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_predictor.py 240 | if "WORLD_SIZE" not in os.environ: 241 | raise RuntimeError("This script must be launched with `torchrun`.") 242 | 243 | config_instance = Config() 244 | main(config_instance.__dict__) 245 | -------------------------------------------------------------------------------- /finetune/train_tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import time 5 | from time import gmtime, strftime 6 | import argparse 7 | import datetime 8 | import torch.distributed as dist 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader 12 | from torch.utils.data.distributed import DistributedSampler 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | 15 | import comet_ml 16 | 17 | # Ensure project root is in path 18 | sys.path.append("../") 19 | from config import Config 20 | from dataset import QlibDataset 21 | from model.kronos import KronosTokenizer 22 | # Import shared utilities 23 | from utils.training_utils import ( 24 | setup_ddp, 25 | cleanup_ddp, 26 | set_seed, 27 | get_model_size, 28 | format_time, 29 | ) 30 | 31 | 32 | def create_dataloaders(config: dict, rank: int, world_size: int): 33 | """ 34 | Creates and returns distributed dataloaders for training and validation. 35 | 36 | Args: 37 | config (dict): A dictionary of configuration parameters. 38 | rank (int): The global rank of the current process. 39 | world_size (int): The total number of processes. 40 | 41 | Returns: 42 | tuple: A tuple containing (train_loader, val_loader, train_dataset, valid_dataset). 
43 | """ 44 | print(f"[Rank {rank}] Creating distributed dataloaders...") 45 | train_dataset = QlibDataset('train') 46 | valid_dataset = QlibDataset('val') 47 | print(f"[Rank {rank}] Train dataset size: {len(train_dataset)}, Validation dataset size: {len(valid_dataset)}") 48 | 49 | train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True) 50 | val_sampler = DistributedSampler(valid_dataset, num_replicas=world_size, rank=rank, shuffle=False) 51 | 52 | train_loader = DataLoader( 53 | train_dataset, 54 | batch_size=config['batch_size'], 55 | sampler=train_sampler, 56 | shuffle=False, # Shuffle is handled by the sampler 57 | num_workers=config.get('num_workers', 2), 58 | pin_memory=True, 59 | drop_last=True 60 | ) 61 | val_loader = DataLoader( 62 | valid_dataset, 63 | batch_size=config['batch_size'], 64 | sampler=val_sampler, 65 | shuffle=False, 66 | num_workers=config.get('num_workers', 2), 67 | pin_memory=True, 68 | drop_last=False 69 | ) 70 | print(f"[Rank {rank}] Dataloaders created. Train steps/epoch: {len(train_loader)}, Val steps: {len(val_loader)}") 71 | return train_loader, val_loader, train_dataset, valid_dataset 72 | 73 | 74 | def train_model(model, device, config, save_dir, logger, rank, world_size): 75 | """ 76 | The main training and validation loop for the tokenizer. 77 | 78 | Args: 79 | model (DDP): The DDP-wrapped model to train. 80 | device (torch.device): The device for the current process. 81 | config (dict): Configuration dictionary. 82 | save_dir (str): Directory to save checkpoints. 83 | logger (comet_ml.Experiment): Comet logger instance. 84 | rank (int): Global rank of the process. 85 | world_size (int): Total number of processes. 86 | 87 | Returns: 88 | tuple: A tuple containing the trained model and a dictionary of results. 
89 | """ 90 | start_time = time.time() 91 | if rank == 0: 92 | effective_bs = config['batch_size'] * world_size * config['accumulation_steps'] 93 | print(f"[Rank {rank}] BATCHSIZE (per GPU): {config['batch_size']}") 94 | print(f"[Rank {rank}] Effective total batch size: {effective_bs}") 95 | 96 | train_loader, val_loader, train_dataset, valid_dataset = create_dataloaders(config, rank, world_size) 97 | 98 | optimizer = torch.optim.AdamW( 99 | model.parameters(), 100 | lr=config['tokenizer_learning_rate'], 101 | weight_decay=config['adam_weight_decay'] 102 | ) 103 | 104 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 105 | optimizer=optimizer, 106 | max_lr=config['tokenizer_learning_rate'], 107 | steps_per_epoch=len(train_loader), 108 | epochs=config['epochs'], 109 | pct_start=0.03, 110 | div_factor=10 111 | ) 112 | 113 | best_val_loss = float('inf') 114 | dt_result = {} 115 | batch_idx_global_train = 0 116 | 117 | for epoch_idx in range(config['epochs']): 118 | epoch_start_time = time.time() 119 | model.train() 120 | train_loader.sampler.set_epoch(epoch_idx) 121 | 122 | # Set dataset seeds for reproducible sampling 123 | train_dataset.set_epoch_seed(epoch_idx * 10000 + rank) 124 | valid_dataset.set_epoch_seed(0) # Keep validation sampling consistent 125 | 126 | for i, (ori_batch_x, _) in enumerate(train_loader): 127 | ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True) 128 | 129 | # --- Gradient Accumulation Loop --- 130 | current_batch_total_loss = 0.0 131 | for j in range(config['accumulation_steps']): 132 | start_idx = j * (ori_batch_x.shape[0] // config['accumulation_steps']) 133 | end_idx = (j + 1) * (ori_batch_x.shape[0] // config['accumulation_steps']) 134 | batch_x = ori_batch_x[start_idx:end_idx] 135 | 136 | # Forward pass 137 | zs, bsq_loss, _, _ = model(batch_x) 138 | z_pre, z = zs 139 | 140 | # Loss calculation 141 | recon_loss_pre = F.mse_loss(z_pre, batch_x) 142 | recon_loss_all = F.mse_loss(z, batch_x) 143 | recon_loss = recon_loss_pre + recon_loss_all 144 | loss = (recon_loss + bsq_loss) / 2 # Assuming w_1=w_2=1 145 | 146 | loss_scaled = loss / config['accumulation_steps'] 147 | current_batch_total_loss += loss.item() 148 | loss_scaled.backward() 149 | 150 | # --- Optimizer Step after Accumulation --- 151 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0) 152 | optimizer.step() 153 | scheduler.step() 154 | optimizer.zero_grad() 155 | 156 | # --- Logging (Master Process Only) --- 157 | if rank == 0 and (batch_idx_global_train + 1) % config['log_interval'] == 0: 158 | avg_loss = current_batch_total_loss / config['accumulation_steps'] 159 | print( 160 | f"[Rank {rank}, Epoch {epoch_idx + 1}/{config['epochs']}, Step {i + 1}/{len(train_loader)}] " 161 | f"LR {optimizer.param_groups[0]['lr']:.6f}, Loss: {avg_loss:.4f}" 162 | ) 163 | if rank == 0 and logger: 164 | avg_loss = current_batch_total_loss / config['accumulation_steps'] 165 | logger.log_metric('train_tokenizer_loss_batch', avg_loss, step=batch_idx_global_train) 166 | logger.log_metric(f'train_vqvae_vq_loss_each_batch', bsq_loss.item(), step=batch_idx_global_train) 167 | logger.log_metric(f'train_recon_loss_pre_each_batch', recon_loss_pre.item(), step=batch_idx_global_train) 168 | logger.log_metric(f'train_recon_loss_each_batch', recon_loss_all.item(), step=batch_idx_global_train) 169 | logger.log_metric('tokenizer_learning_rate', optimizer.param_groups[0]["lr"], step=batch_idx_global_train) 170 | 171 | batch_idx_global_train += 1 172 | 173 | # --- Validation Loop --- 174 | model.eval() 
175 | tot_val_loss_sum_rank = 0.0 176 | val_sample_count_rank = 0 177 | with torch.no_grad(): 178 | for ori_batch_x, _ in val_loader: 179 | ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True) 180 | zs, _, _, _ = model(ori_batch_x) 181 | _, z = zs 182 | val_loss_item = F.mse_loss(z, ori_batch_x) 183 | 184 | tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0) 185 | val_sample_count_rank += ori_batch_x.size(0) 186 | 187 | # Reduce validation losses from all processes 188 | val_loss_sum_tensor = torch.tensor(tot_val_loss_sum_rank, device=device) 189 | val_count_tensor = torch.tensor(val_sample_count_rank, device=device) 190 | dist.all_reduce(val_loss_sum_tensor, op=dist.ReduceOp.SUM) 191 | dist.all_reduce(val_count_tensor, op=dist.ReduceOp.SUM) 192 | 193 | avg_val_loss = val_loss_sum_tensor.item() / val_count_tensor.item() if val_count_tensor.item() > 0 else 0 194 | 195 | # --- End of Epoch Summary & Checkpointing (Master Process Only) --- 196 | if rank == 0: 197 | print(f"\n--- Epoch {epoch_idx + 1}/{config['epochs']} Summary ---") 198 | print(f"Validation Loss: {avg_val_loss:.4f}") 199 | print(f"Time This Epoch: {format_time(time.time() - epoch_start_time)}") 200 | print(f"Total Time Elapsed: {format_time(time.time() - start_time)}\n") 201 | if logger: 202 | logger.log_metric('val_tokenizer_loss_epoch', avg_val_loss, epoch=epoch_idx) 203 | 204 | if avg_val_loss < best_val_loss: 205 | best_val_loss = avg_val_loss 206 | save_path = f"{save_dir}/checkpoints/best_model" 207 | model.module.save_pretrained(save_path) 208 | print(f"Best model saved to {save_path} (Val Loss: {best_val_loss:.4f})") 209 | if logger: 210 | logger.log_model("best_model", save_path) 211 | 212 | dist.barrier() # Ensure all processes finish the epoch before starting the next one. 213 | 214 | dt_result['best_val_loss'] = best_val_loss 215 | return model, dt_result 216 | 217 | 218 | def main(config: dict): 219 | """ 220 | Main function to orchestrate the DDP training process. 
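    The script is meant to be launched with torchrun, as noted at the bottom of
    this file. An illustrative two-GPU launch (the GPU count is an example):

        torchrun --standalone --nproc_per_node=2 train_tokenizer.py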
221 | """ 222 | rank, world_size, local_rank = setup_ddp() 223 | device = torch.device(f"cuda:{local_rank}") 224 | set_seed(config['seed'], rank) 225 | 226 | save_dir = os.path.join(config['save_path'], config['tokenizer_save_folder_name']) 227 | 228 | # Logger and summary setup (master process only) 229 | comet_logger, master_summary = None, {} 230 | if rank == 0: 231 | os.makedirs(os.path.join(save_dir, 'checkpoints'), exist_ok=True) 232 | master_summary = { 233 | 'start_time': strftime("%Y-%m-%dT%H-%M-%S", gmtime()), 234 | 'save_directory': save_dir, 235 | 'world_size': world_size, 236 | } 237 | if config['use_comet']: 238 | comet_logger = comet_ml.Experiment( 239 | api_key=config['comet_config']['api_key'], 240 | project_name=config['comet_config']['project_name'], 241 | workspace=config['comet_config']['workspace'], 242 | ) 243 | comet_logger.add_tag(config['comet_tag']) 244 | comet_logger.set_name(config['comet_name']) 245 | comet_logger.log_parameters(config) 246 | print("Comet Logger Initialized.") 247 | 248 | dist.barrier() # Ensure save directory is created before proceeding 249 | 250 | # Model Initialization 251 | model = KronosTokenizer.from_pretrained(config['pretrained_tokenizer_path']) 252 | model.to(device) 253 | model = DDP(model, device_ids=[local_rank], find_unused_parameters=False) 254 | 255 | if rank == 0: 256 | print(f"Model Size: {get_model_size(model.module)}") 257 | 258 | # Start Training 259 | _, dt_result = train_model( 260 | model, device, config, save_dir, comet_logger, rank, world_size 261 | ) 262 | 263 | # Finalize and save summary (master process only) 264 | if rank == 0: 265 | master_summary['final_result'] = dt_result 266 | with open(os.path.join(save_dir, 'summary.json'), 'w') as f: 267 | json.dump(master_summary, f, indent=4) 268 | print('Training finished. 
Summary file saved.') 269 | if comet_logger: 270 | comet_logger.end() 271 | 272 | cleanup_ddp() 273 | 274 | 275 | if __name__ == '__main__': 276 | # Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_tokenizer.py 277 | if "WORLD_SIZE" not in os.environ: 278 | raise RuntimeError("This script must be launched with `torchrun`.") 279 | 280 | config_instance = Config() 281 | main(config_instance.__dict__) 282 | -------------------------------------------------------------------------------- /finetune_csv/config_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from typing import Dict, Any 4 | 5 | 6 | class ConfigLoader: 7 | 8 | def __init__(self, config_path: str): 9 | 10 | self.config_path = config_path 11 | self.config = self._load_config() 12 | 13 | def _load_config(self) -> Dict[str, Any]: 14 | 15 | if not os.path.exists(self.config_path): 16 | raise FileNotFoundError(f"config file not found: {self.config_path}") 17 | 18 | with open(self.config_path, 'r', encoding='utf-8') as f: 19 | config = yaml.safe_load(f) 20 | 21 | config = self._resolve_dynamic_paths(config) 22 | 23 | return config 24 | 25 | def _resolve_dynamic_paths(self, config: Dict[str, Any]) -> Dict[str, Any]: 26 | 27 | exp_name = config.get('model_paths', {}).get('exp_name', '') 28 | if not exp_name: 29 | return config 30 | 31 | base_path = config.get('model_paths', {}).get('base_path', '') 32 | path_templates = { 33 | 'base_save_path': f"{base_path}/{exp_name}", 34 | 'finetuned_tokenizer': f"{base_path}/{exp_name}/tokenizer/best_model" 35 | } 36 | 37 | if 'model_paths' in config: 38 | for key, template in path_templates.items(): 39 | if key in config['model_paths']: 40 | # only use template when the original value is empty string 41 | current_value = config['model_paths'][key] 42 | if current_value == "" or current_value is None: 43 | config['model_paths'][key] = template 44 | else: 45 | # if the original value is not empty, use template to replace the {exp_name} placeholder 46 | if isinstance(current_value, str) and '{exp_name}' in current_value: 47 | config['model_paths'][key] = current_value.format(exp_name=exp_name) 48 | 49 | return config 50 | 51 | def get(self, key: str, default=None): 52 | 53 | keys = key.split('.') 54 | value = self.config 55 | 56 | try: 57 | for k in keys: 58 | value = value[k] 59 | return value 60 | except (KeyError, TypeError): 61 | return default 62 | 63 | def get_data_config(self) -> Dict[str, Any]: 64 | return self.config.get('data', {}) 65 | 66 | def get_training_config(self) -> Dict[str, Any]: 67 | return self.config.get('training', {}) 68 | 69 | def get_model_paths(self) -> Dict[str, str]: 70 | return self.config.get('model_paths', {}) 71 | 72 | def get_experiment_config(self) -> Dict[str, Any]: 73 | return self.config.get('experiment', {}) 74 | 75 | def get_device_config(self) -> Dict[str, Any]: 76 | return self.config.get('device', {}) 77 | 78 | def get_distributed_config(self) -> Dict[str, Any]: 79 | return self.config.get('distributed', {}) 80 | 81 | def update_config(self, updates: Dict[str, Any]): 82 | 83 | def update_nested_dict(d, u): 84 | for k, v in u.items(): 85 | if isinstance(v, dict): 86 | d[k] = update_nested_dict(d.get(k, {}), v) 87 | else: 88 | d[k] = v 89 | return d 90 | 91 | self.config = update_nested_dict(self.config, updates) 92 | 93 | def save_config(self, save_path: str = None): 94 | 95 | if save_path is None: 96 | save_path = self.config_path 97 | 98 | with open(save_path, 'w', 
encoding='utf-8') as f: 99 | yaml.dump(self.config, f, default_flow_style=False, allow_unicode=True, indent=2) 100 | 101 | def print_config(self): 102 | print("=" * 50) 103 | print("Current configuration:") 104 | print("=" * 50) 105 | yaml.dump(self.config, default_flow_style=False, allow_unicode=True, indent=2) 106 | print("=" * 50) 107 | 108 | 109 | class CustomFinetuneConfig: 110 | 111 | def __init__(self, config_path: str = None): 112 | 113 | if config_path is None: 114 | config_path = os.path.join(os.path.dirname(__file__), 'config.yaml') 115 | 116 | self.loader = ConfigLoader(config_path) 117 | self._load_all_configs() 118 | 119 | def _load_all_configs(self): 120 | 121 | data_config = self.loader.get_data_config() 122 | self.data_path = data_config.get('data_path') 123 | self.lookback_window = data_config.get('lookback_window', 512) 124 | self.predict_window = data_config.get('predict_window', 48) 125 | self.max_context = data_config.get('max_context', 512) 126 | self.clip = data_config.get('clip', 5.0) 127 | self.train_ratio = data_config.get('train_ratio', 0.9) 128 | self.val_ratio = data_config.get('val_ratio', 0.1) 129 | self.test_ratio = data_config.get('test_ratio', 0.0) 130 | 131 | # training configuration 132 | training_config = self.loader.get_training_config() 133 | # support training epochs of tokenizer and basemodel separately 134 | self.tokenizer_epochs = training_config.get('tokenizer_epochs', 30) 135 | self.basemodel_epochs = training_config.get('basemodel_epochs', 30) 136 | 137 | if 'epochs' in training_config and 'tokenizer_epochs' not in training_config: 138 | self.tokenizer_epochs = training_config.get('epochs', 30) 139 | if 'epochs' in training_config and 'basemodel_epochs' not in training_config: 140 | self.basemodel_epochs = training_config.get('epochs', 30) 141 | 142 | self.batch_size = training_config.get('batch_size', 160) 143 | self.log_interval = training_config.get('log_interval', 50) 144 | self.num_workers = training_config.get('num_workers', 6) 145 | self.seed = training_config.get('seed', 100) 146 | self.tokenizer_learning_rate = training_config.get('tokenizer_learning_rate', 2e-4) 147 | self.predictor_learning_rate = training_config.get('predictor_learning_rate', 4e-5) 148 | self.adam_beta1 = training_config.get('adam_beta1', 0.9) 149 | self.adam_beta2 = training_config.get('adam_beta2', 0.95) 150 | self.adam_weight_decay = training_config.get('adam_weight_decay', 0.1) 151 | self.accumulation_steps = training_config.get('accumulation_steps', 1) 152 | 153 | model_paths = self.loader.get_model_paths() 154 | self.exp_name = model_paths.get('exp_name', 'default_experiment') 155 | self.pretrained_tokenizer_path = model_paths.get('pretrained_tokenizer') 156 | self.pretrained_predictor_path = model_paths.get('pretrained_predictor') 157 | self.base_save_path = model_paths.get('base_save_path') 158 | self.tokenizer_save_name = model_paths.get('tokenizer_save_name', 'tokenizer') 159 | self.basemodel_save_name = model_paths.get('basemodel_save_name', 'basemodel') 160 | self.finetuned_tokenizer_path = model_paths.get('finetuned_tokenizer') 161 | 162 | experiment_config = self.loader.get_experiment_config() 163 | self.experiment_name = experiment_config.get('name', 'kronos_custom_finetune') 164 | self.experiment_description = experiment_config.get('description', '') 165 | self.use_comet = experiment_config.get('use_comet', False) 166 | self.train_tokenizer = experiment_config.get('train_tokenizer', True) 167 | self.train_basemodel = 
experiment_config.get('train_basemodel', True) 168 | self.skip_existing = experiment_config.get('skip_existing', False) 169 | 170 | unified_pretrained = experiment_config.get('pre_trained', None) 171 | self.pre_trained_tokenizer = experiment_config.get('pre_trained_tokenizer', unified_pretrained if unified_pretrained is not None else True) 172 | self.pre_trained_predictor = experiment_config.get('pre_trained_predictor', unified_pretrained if unified_pretrained is not None else True) 173 | 174 | device_config = self.loader.get_device_config() 175 | self.use_cuda = device_config.get('use_cuda', True) 176 | self.device_id = device_config.get('device_id', 0) 177 | 178 | distributed_config = self.loader.get_distributed_config() 179 | self.use_ddp = distributed_config.get('use_ddp', False) 180 | self.ddp_backend = distributed_config.get('backend', 'nccl') 181 | 182 | self._compute_full_paths() 183 | 184 | def _compute_full_paths(self): 185 | 186 | self.tokenizer_save_path = os.path.join(self.base_save_path, self.tokenizer_save_name) 187 | self.tokenizer_best_model_path = os.path.join(self.tokenizer_save_path, 'best_model') 188 | 189 | self.basemodel_save_path = os.path.join(self.base_save_path, self.basemodel_save_name) 190 | self.basemodel_best_model_path = os.path.join(self.basemodel_save_path, 'best_model') 191 | 192 | def get_tokenizer_config(self): 193 | 194 | return { 195 | 'data_path': self.data_path, 196 | 'lookback_window': self.lookback_window, 197 | 'predict_window': self.predict_window, 198 | 'max_context': self.max_context, 199 | 'clip': self.clip, 200 | 'train_ratio': self.train_ratio, 201 | 'val_ratio': self.val_ratio, 202 | 'test_ratio': self.test_ratio, 203 | 'epochs': self.tokenizer_epochs, 204 | 'batch_size': self.batch_size, 205 | 'log_interval': self.log_interval, 206 | 'num_workers': self.num_workers, 207 | 'seed': self.seed, 208 | 'learning_rate': self.tokenizer_learning_rate, 209 | 'adam_beta1': self.adam_beta1, 210 | 'adam_beta2': self.adam_beta2, 211 | 'adam_weight_decay': self.adam_weight_decay, 212 | 'accumulation_steps': self.accumulation_steps, 213 | 'pretrained_model_path': self.pretrained_tokenizer_path, 214 | 'save_path': self.tokenizer_save_path, 215 | 'use_comet': self.use_comet 216 | } 217 | 218 | def get_basemodel_config(self): 219 | 220 | return { 221 | 'data_path': self.data_path, 222 | 'lookback_window': self.lookback_window, 223 | 'predict_window': self.predict_window, 224 | 'max_context': self.max_context, 225 | 'clip': self.clip, 226 | 'train_ratio': self.train_ratio, 227 | 'val_ratio': self.val_ratio, 228 | 'test_ratio': self.test_ratio, 229 | 'epochs': self.basemodel_epochs, 230 | 'batch_size': self.batch_size, 231 | 'log_interval': self.log_interval, 232 | 'num_workers': self.num_workers, 233 | 'seed': self.seed, 234 | 'predictor_learning_rate': self.predictor_learning_rate, 235 | 'tokenizer_learning_rate': self.tokenizer_learning_rate, 236 | 'adam_beta1': self.adam_beta1, 237 | 'adam_beta2': self.adam_beta2, 238 | 'adam_weight_decay': self.adam_weight_decay, 239 | 'pretrained_tokenizer_path': self.finetuned_tokenizer_path, 240 | 'pretrained_predictor_path': self.pretrained_predictor_path, 241 | 'save_path': self.basemodel_save_path, 242 | 'use_comet': self.use_comet 243 | } 244 | 245 | def print_config_summary(self): 246 | 247 | print("=" * 60) 248 | print("Kronos finetuning configuration summary") 249 | print("=" * 60) 250 | print(f"Experiment name: {self.exp_name}") 251 | print(f"Data path: {self.data_path}") 252 | print(f"Lookback window: 
{self.lookback_window}") 253 | print(f"Predict window: {self.predict_window}") 254 | print(f"Tokenizer training epochs: {self.tokenizer_epochs}") 255 | print(f"Basemodel training epochs: {self.basemodel_epochs}") 256 | print(f"Batch size: {self.batch_size}") 257 | print(f"Tokenizer learning rate: {self.tokenizer_learning_rate}") 258 | print(f"Predictor learning rate: {self.predictor_learning_rate}") 259 | print(f"Train tokenizer: {self.train_tokenizer}") 260 | print(f"Train basemodel: {self.train_basemodel}") 261 | print(f"Skip existing: {self.skip_existing}") 262 | print(f"Use pre-trained tokenizer: {self.pre_trained_tokenizer}") 263 | print(f"Use pre-trained predictor: {self.pre_trained_predictor}") 264 | print(f"Base save path: {self.base_save_path}") 265 | print(f"Tokenizer save path: {self.tokenizer_save_path}") 266 | print(f"Basemodel save path: {self.basemodel_save_path}") 267 | print("=" * 60) 268 | -------------------------------------------------------------------------------- /finetune_csv/finetune_tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import time 5 | import random 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.utils.data import DataLoader 10 | from torch.utils.data.distributed import DistributedSampler 11 | from time import gmtime, strftime 12 | import datetime 13 | import logging 14 | from logging.handlers import RotatingFileHandler 15 | import torch.distributed as dist 16 | from torch.nn.parallel import DistributedDataParallel as DDP 17 | 18 | sys.path.append("../") 19 | from model import KronosTokenizer 20 | from finetune_base_model import CustomKlineDataset 21 | from config_loader import CustomFinetuneConfig 22 | 23 | 24 | def set_seed(seed: int, rank: int = 0): 25 | actual_seed = seed 26 | random.seed(actual_seed) 27 | np.random.seed(actual_seed) 28 | torch.manual_seed(actual_seed) 29 | if torch.cuda.is_available(): 30 | torch.cuda.manual_seed_all(actual_seed) 31 | torch.backends.cudnn.deterministic = True 32 | torch.backends.cudnn.benchmark = False 33 | 34 | 35 | def get_model_size(model: torch.nn.Module) -> str: 36 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 37 | if total_params >= 1e9: 38 | return f"{total_params / 1e9:.1f}B" 39 | elif total_params >= 1e6: 40 | return f"{total_params / 1e6:.1f}M" 41 | else: 42 | return f"{total_params / 1e3:.1f}K" 43 | 44 | 45 | def format_time(seconds: float) -> str: 46 | return str(datetime.timedelta(seconds=int(seconds))) 47 | 48 | 49 | def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger: 50 | os.makedirs(log_dir, exist_ok=True) 51 | 52 | logger = logging.getLogger(f"tokenizer_training_rank_{rank}") 53 | logger.setLevel(logging.INFO) 54 | 55 | if logger.handlers: 56 | return logger 57 | 58 | log_file = os.path.join(log_dir, f"tokenizer_training_rank_{rank}.log") 59 | file_handler = RotatingFileHandler( 60 | log_file, 61 | maxBytes=10*1024*1024, 62 | backupCount=5, 63 | encoding='utf-8' 64 | ) 65 | file_handler.setLevel(logging.INFO) 66 | 67 | console_handler = None 68 | if rank == 0: 69 | console_handler = logging.StreamHandler() 70 | console_handler.setLevel(logging.INFO) 71 | 72 | formatter = logging.Formatter( 73 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s', 74 | datefmt='%Y-%m-%d %H:%M:%S' 75 | ) 76 | file_handler.setFormatter(formatter) 77 | if console_handler is not None: 78 | 
console_handler.setFormatter(formatter) 79 | 80 | logger.addHandler(file_handler) 81 | if console_handler is not None: 82 | logger.addHandler(console_handler) 83 | 84 | logger.info(f"=== Tokenizer Training Started ===") 85 | logger.info(f"Experiment Name: {exp_name}") 86 | logger.info(f"Log Directory: {log_dir}") 87 | logger.info(f"Rank: {rank}") 88 | logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") 89 | 90 | return logger 91 | 92 | 93 | def create_dataloaders(config): 94 | if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0: 95 | print("Creating tokenizer training data loaders...") 96 | 97 | train_dataset = CustomKlineDataset( 98 | data_path=config.data_path, 99 | data_type="train", 100 | lookback_window=config.lookback_window, 101 | predict_window=config.predict_window, 102 | clip=config.clip, 103 | seed=config.seed, 104 | train_ratio=config.train_ratio, 105 | val_ratio=config.val_ratio, 106 | test_ratio=config.test_ratio 107 | ) 108 | 109 | val_dataset = CustomKlineDataset( 110 | data_path=config.data_path, 111 | data_type="val", 112 | lookback_window=config.lookback_window, 113 | predict_window=config.predict_window, 114 | clip=config.clip, 115 | seed=config.seed + 1, 116 | train_ratio=config.train_ratio, 117 | val_ratio=config.val_ratio, 118 | test_ratio=config.test_ratio 119 | ) 120 | 121 | use_ddp = dist.is_available() and dist.is_initialized() 122 | train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) if use_ddp else None 123 | val_sampler = DistributedSampler(val_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=False) if use_ddp else None 124 | 125 | train_loader = DataLoader( 126 | train_dataset, 127 | batch_size=config.batch_size, 128 | shuffle=(train_sampler is None), 129 | num_workers=config.num_workers, 130 | pin_memory=True, 131 | drop_last=True, 132 | sampler=train_sampler 133 | ) 134 | 135 | val_loader = DataLoader( 136 | val_dataset, 137 | batch_size=config.batch_size, 138 | shuffle=False, 139 | num_workers=config.num_workers, 140 | pin_memory=True, 141 | drop_last=False, 142 | sampler=val_sampler 143 | ) 144 | 145 | if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0: 146 | print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}") 147 | 148 | return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler 149 | 150 | 151 | def train_tokenizer(model, device, config, save_dir, logger): 152 | logger.info("Starting tokenizer training...") 153 | use_ddp = dist.is_available() and dist.is_initialized() 154 | rank = dist.get_rank() if use_ddp else 0 155 | world_size = dist.get_world_size() if use_ddp else 1 156 | 157 | train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config) 158 | 159 | optimizer = torch.optim.AdamW( 160 | model.parameters(), 161 | lr=config.tokenizer_learning_rate, 162 | weight_decay=config.adam_weight_decay 163 | ) 164 | 165 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 166 | optimizer, 167 | max_lr=config.tokenizer_learning_rate, 168 | steps_per_epoch=len(train_loader), 169 | epochs=config.tokenizer_epochs, 170 | pct_start=0.03, 171 | div_factor=10 172 | ) 173 | 174 | if use_ddp: 175 | local_rank = int(os.environ.get("LOCAL_RANK", "0")) 176 | model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False) 177 | 
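    # Illustrative sketch (not part of the original finetune_tokenizer.py): the loop
    # below splits each loaded batch into `accumulation_steps` micro-batches and
    # divides each loss by that factor before backward(), so the accumulated gradient
    # matches a single update over the full batch. The same pattern in isolation,
    # with a toy model (all names here are hypothetical):
    #
    #     import torch
    #     net = torch.nn.Linear(8, 1)
    #     opt = torch.optim.AdamW(net.parameters(), lr=1e-3)
    #     x, y = torch.randn(16, 8), torch.randn(16, 1)
    #     steps, micro = 4, 16 // 4
    #     opt.zero_grad()
    #     for j in range(steps):
    #         xb, yb = x[j * micro:(j + 1) * micro], y[j * micro:(j + 1) * micro]
    #         loss = torch.nn.functional.mse_loss(net(xb), yb)
    #         (loss / steps).backward()  # scale so summed grads match a full-batch step
    #     torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=2.0)
    #     opt.step()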
178 | best_val_loss = float("inf") 179 | batch_idx_global = 0 180 | 181 | accumulation_steps = getattr(config, 'accumulation_steps', 1) 182 | 183 | for epoch in range(config.tokenizer_epochs): 184 | epoch_start_time = time.time() 185 | model.train() 186 | 187 | train_dataset.set_epoch_seed(epoch * 10000) 188 | val_dataset.set_epoch_seed(0) 189 | if train_sampler is not None: 190 | train_sampler.set_epoch(epoch) 191 | 192 | for batch_idx, (ori_batch_x, _) in enumerate(train_loader): 193 | ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True) 194 | 195 | current_batch_total_loss = 0.0 196 | for j in range(accumulation_steps): 197 | start_idx = j * (ori_batch_x.shape[0] // accumulation_steps) 198 | end_idx = (j + 1) * (ori_batch_x.shape[0] // accumulation_steps) 199 | batch_x = ori_batch_x[start_idx:end_idx] 200 | 201 | zs, bsq_loss, _, _ = (model.module if use_ddp else model)(batch_x) 202 | z_pre, z = zs 203 | 204 | recon_loss_pre = F.mse_loss(z_pre, batch_x) 205 | recon_loss_all = F.mse_loss(z, batch_x) 206 | recon_loss = recon_loss_pre + recon_loss_all 207 | loss = (recon_loss + bsq_loss) / 2 208 | 209 | loss_scaled = loss / accumulation_steps 210 | current_batch_total_loss += loss.item() 211 | loss_scaled.backward() 212 | 213 | torch.nn.utils.clip_grad_norm_((model.module if use_ddp else model).parameters(), max_norm=2.0) 214 | optimizer.step() 215 | scheduler.step() 216 | optimizer.zero_grad() 217 | 218 | if (batch_idx_global + 1) % config.log_interval == 0: 219 | avg_loss = current_batch_total_loss / accumulation_steps 220 | lr = optimizer.param_groups[0]["lr"] 221 | log_msg = (f"[Epoch {epoch+1}/{config.tokenizer_epochs}, Step {batch_idx+1}/{len(train_loader)}] " 222 | f"LR: {lr:.6f}, Loss: {avg_loss:.4f}") 223 | logger.info(log_msg) 224 | if rank == 0: 225 | print(log_msg) 226 | 227 | detail_msg = (f" - VQ Loss: {bsq_loss.item():.4f}\n" 228 | f" - Recon Loss Pre: {recon_loss_pre.item():.4f}\n" 229 | f" - Recon Loss All: {recon_loss_all.item():.4f}") 230 | logger.info(detail_msg) 231 | if rank == 0: 232 | print(detail_msg) 233 | 234 | batch_idx_global += 1 235 | 236 | model.eval() 237 | tot_val_loss_sum_rank = 0.0 238 | val_sample_count_rank = 0 239 | 240 | with torch.no_grad(): 241 | for ori_batch_x, _ in val_loader: 242 | ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True) 243 | zs, _, _, _ = (model.module if use_ddp else model)(ori_batch_x) 244 | _, z = zs 245 | val_loss_item = F.mse_loss(z, ori_batch_x) 246 | 247 | tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0) 248 | val_sample_count_rank += ori_batch_x.size(0) 249 | 250 | if use_ddp: 251 | tensor_sum = torch.tensor([tot_val_loss_sum_rank, val_sample_count_rank], dtype=torch.float64, device=device) 252 | dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM) 253 | tot_val_loss_all = tensor_sum[0].item() 254 | val_count_all = int(tensor_sum[1].item()) 255 | avg_val_loss = (tot_val_loss_all / val_count_all) if val_count_all > 0 else 0.0 256 | else: 257 | avg_val_loss = tot_val_loss_sum_rank / val_sample_count_rank if val_sample_count_rank > 0 else 0 258 | 259 | epoch_time = time.time() - epoch_start_time 260 | epoch_summary = (f"\n--- Epoch {epoch+1}/{config.tokenizer_epochs} Summary ---\n" 261 | f"Validation Loss: {avg_val_loss:.4f}\n" 262 | f"Epoch Time: {format_time(epoch_time)}\n" 263 | f"Total Training Time: {format_time(time.time() - epoch_start_time)}\n") 264 | logger.info(epoch_summary) 265 | if rank == 0: 266 | print(epoch_summary) 267 | 268 | if avg_val_loss < best_val_loss: 269 | 
best_val_loss = avg_val_loss 270 | if rank == 0: 271 | model_save_path = os.path.join(save_dir, "best_model") 272 | os.makedirs(model_save_path, exist_ok=True) 273 | (model.module if use_ddp else model).save_pretrained(model_save_path) 274 | save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})" 275 | logger.info(save_msg) 276 | print(save_msg) 277 | 278 | return best_val_loss 279 | 280 | 281 | def main(): 282 | import argparse 283 | 284 | parser = argparse.ArgumentParser(description='Kronos Tokenizer Fine-tuning Training') 285 | parser.add_argument('--config', type=str, default='config.yaml', 286 | help='Configuration file path (default: config.yaml)') 287 | args = parser.parse_args() 288 | 289 | config = CustomFinetuneConfig(args.config) 290 | 291 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 292 | print(f"Using device: {device}") 293 | 294 | config = CustomFinetuneConfig(args.config) 295 | 296 | os.makedirs(config.tokenizer_save_path, exist_ok=True) 297 | 298 | log_dir = os.path.join(config.base_save_path, "logs") 299 | logger = setup_logging(config.exp_name, log_dir, 0) 300 | 301 | set_seed(config.seed) 302 | 303 | # 加载预训练tokenizer 304 | if getattr(config, 'pre_trained_tokenizer', True): 305 | logger.info("Loading pretrained tokenizer...") 306 | print("Loading pretrained tokenizer...") 307 | tokenizer = KronosTokenizer.from_pretrained(config.pretrained_tokenizer_path) 308 | else: 309 | print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture") 310 | import json, os 311 | cfg_path = os.path.join(config.pretrained_tokenizer_path, 'config.json') 312 | with open(cfg_path, 'r') as f: 313 | arch = json.load(f) 314 | tokenizer = KronosTokenizer( 315 | d_in=arch.get('d_in', 6), 316 | d_model=arch.get('d_model', 256), 317 | n_heads=arch.get('n_heads', 4), 318 | ff_dim=arch.get('ff_dim', 512), 319 | n_enc_layers=arch.get('n_enc_layers', 4), 320 | n_dec_layers=arch.get('n_dec_layers', 4), 321 | ffn_dropout_p=arch.get('ffn_dropout_p', 0.0), 322 | attn_dropout_p=arch.get('attn_dropout_p', 0.0), 323 | resid_dropout_p=arch.get('resid_dropout_p', 0.0), 324 | s1_bits=arch.get('s1_bits', 10), 325 | s2_bits=arch.get('s2_bits', 10), 326 | beta=arch.get('beta', 0.05), 327 | gamma0=arch.get('gamma0', 1.0), 328 | gamma=arch.get('gamma', 1.1), 329 | zeta=arch.get('zeta', 0.05), 330 | group_size=arch.get('group_size', 4) 331 | ) 332 | tokenizer = tokenizer.to(device) 333 | 334 | model_size = get_model_size(tokenizer) 335 | logger.info(f"Tokenizer parameters: {model_size}") 336 | print(f"Tokenizer parameters: {model_size}") 337 | 338 | logger.info("=== Training Configuration ===") 339 | logger.info(f"Data path: {config.data_path}") 340 | logger.info(f"Lookback window: {config.lookback_window}") 341 | logger.info(f"Predict window: {config.predict_window}") 342 | logger.info(f"Batch size: {config.batch_size}") 343 | logger.info(f"Learning rate: {config.tokenizer_learning_rate}") 344 | logger.info(f"Training epochs: {config.tokenizer_epochs}") 345 | logger.info(f"Device: {device}") 346 | logger.info(f"Distributed training: False") 347 | 348 | logger.info("Starting tokenizer fine-tuning training...") 349 | print("Starting tokenizer fine-tuning training...") 350 | best_val_loss = train_tokenizer(tokenizer, device, config, config.tokenizer_save_path, logger) 351 | 352 | final_msg = f"Tokenizer training completed! 
Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.tokenizer_save_path}" 353 | logger.info(final_msg) 354 | print(final_msg) 355 | 356 | 357 | if __name__ == "__main__": 358 | main() 359 | 360 | -------------------------------------------------------------------------------- /finetune/qlib_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import pickle 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import torch 10 | from torch.utils.data import Dataset, DataLoader 11 | from tqdm import trange, tqdm 12 | from matplotlib import pyplot as plt 13 | 14 | import qlib 15 | from qlib.config import REG_CN 16 | from qlib.backtest import backtest, executor, CommonInfrastructure 17 | from qlib.contrib.evaluate import risk_analysis 18 | from qlib.contrib.strategy import TopkDropoutStrategy 19 | from qlib.utils import flatten_dict 20 | from qlib.utils.time import Freq 21 | 22 | # Ensure project root is in the Python path 23 | sys.path.append("../") 24 | from config import Config 25 | from model.kronos import Kronos, KronosTokenizer, auto_regressive_inference 26 | 27 | 28 | # ================================================================================= 29 | # 1. Data Loading and Processing for Inference 30 | # ================================================================================= 31 | 32 | class QlibTestDataset(Dataset): 33 | """ 34 | PyTorch Dataset for handling Qlib test data, specifically for inference. 35 | 36 | This dataset iterates through all possible sliding windows sequentially. It also 37 | yields metadata like symbol and timestamp, which are crucial for mapping 38 | predictions back to the original time series. 
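    Illustrative item structure (not in the original docstring; shapes follow
    the lookback/predict windows configured in config.py):

        x, x_stamp, y_stamp, symbol, timestamp = dataset[0]
        # x:       (lookback_window, len(feature_list)), instance-normalized
        # x_stamp: (lookback_window, len(time_feature_list))
        # y_stamp: (predict_window, len(time_feature_list))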
39 | """ 40 | 41 | def __init__(self, data: dict, config: Config): 42 | self.data = data 43 | self.config = config 44 | self.window_size = config.lookback_window + config.predict_window 45 | self.symbols = list(self.data.keys()) 46 | self.feature_list = config.feature_list 47 | self.time_feature_list = config.time_feature_list 48 | self.indices = [] 49 | 50 | print("Preprocessing and building indices for test dataset...") 51 | for symbol in self.symbols: 52 | df = self.data[symbol].reset_index() 53 | # Generate time features on-the-fly 54 | df['minute'] = df['datetime'].dt.minute 55 | df['hour'] = df['datetime'].dt.hour 56 | df['weekday'] = df['datetime'].dt.weekday 57 | df['day'] = df['datetime'].dt.day 58 | df['month'] = df['datetime'].dt.month 59 | self.data[symbol] = df # Store preprocessed dataframe 60 | 61 | num_samples = len(df) - self.window_size + 1 62 | if num_samples > 0: 63 | for i in range(num_samples): 64 | timestamp = df.iloc[i + self.config.lookback_window - 1]['datetime'] 65 | self.indices.append((symbol, i, timestamp)) 66 | 67 | def __len__(self) -> int: 68 | return len(self.indices) 69 | 70 | def __getitem__(self, idx: int): 71 | symbol, start_idx, timestamp = self.indices[idx] 72 | df = self.data[symbol] 73 | 74 | context_end = start_idx + self.config.lookback_window 75 | predict_end = context_end + self.config.predict_window 76 | 77 | context_df = df.iloc[start_idx:context_end] 78 | predict_df = df.iloc[context_end:predict_end] 79 | 80 | x = context_df[self.feature_list].values.astype(np.float32) 81 | x_stamp = context_df[self.time_feature_list].values.astype(np.float32) 82 | y_stamp = predict_df[self.time_feature_list].values.astype(np.float32) 83 | 84 | # Instance-level normalization, consistent with training 85 | x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0) 86 | x = (x - x_mean) / (x_std + 1e-5) 87 | x = np.clip(x, -self.config.clip, self.config.clip) 88 | 89 | return torch.from_numpy(x), torch.from_numpy(x_stamp), torch.from_numpy(y_stamp), symbol, timestamp 90 | 91 | 92 | # ================================================================================= 93 | # 2. Backtesting Logic 94 | # ================================================================================= 95 | 96 | class QlibBacktest: 97 | """ 98 | A wrapper class for conducting backtesting experiments using Qlib. 99 | """ 100 | 101 | def __init__(self, config: Config): 102 | self.config = config 103 | self.initialize_qlib() 104 | 105 | def initialize_qlib(self): 106 | """Initializes the Qlib environment.""" 107 | print("Initializing Qlib for backtesting...") 108 | qlib.init(provider_uri=self.config.qlib_data_path, region=REG_CN) 109 | 110 | def run_single_backtest(self, signal_series: pd.Series) -> pd.DataFrame: 111 | """ 112 | Runs a single backtest for a given prediction signal. 113 | 114 | Args: 115 | signal_series (pd.Series): A pandas Series with a MultiIndex 116 | (instrument, datetime) and prediction scores. 117 | Returns: 118 | pd.DataFrame: A DataFrame containing the performance report. 
119 | """ 120 | strategy = TopkDropoutStrategy( 121 | topk=self.config.backtest_n_symbol_hold, 122 | n_drop=self.config.backtest_n_symbol_drop, 123 | hold_thresh=self.config.backtest_hold_thresh, 124 | signal=signal_series, 125 | ) 126 | executor_config = { 127 | "time_per_step": "day", 128 | "generate_portfolio_metrics": True, 129 | "delay_execution": True, 130 | } 131 | backtest_config = { 132 | "start_time": self.config.backtest_time_range[0], 133 | "end_time": self.config.backtest_time_range[1], 134 | "account": 100_000_000, 135 | "benchmark": self.config.backtest_benchmark, 136 | "exchange_kwargs": { 137 | "freq": "day", "limit_threshold": 0.095, "deal_price": "open", 138 | "open_cost": 0.001, "close_cost": 0.0015, "min_cost": 5, 139 | }, 140 | "executor": executor.SimulatorExecutor(**executor_config), 141 | } 142 | 143 | portfolio_metric_dict, _ = backtest(strategy=strategy, **backtest_config) 144 | analysis_freq = "{0}{1}".format(*Freq.parse("day")) 145 | report, _ = portfolio_metric_dict.get(analysis_freq) 146 | 147 | # --- Analysis and Reporting --- 148 | analysis = { 149 | "excess_return_without_cost": risk_analysis(report["return"] - report["bench"], freq=analysis_freq), 150 | "excess_return_with_cost": risk_analysis(report["return"] - report["bench"] - report["cost"], freq=analysis_freq), 151 | } 152 | print("\n--- Backtest Analysis ---") 153 | print("Benchmark Return:", risk_analysis(report["bench"], freq=analysis_freq), sep='\n') 154 | print("\nExcess Return (w/o cost):", analysis["excess_return_without_cost"], sep='\n') 155 | print("\nExcess Return (w/ cost):", analysis["excess_return_with_cost"], sep='\n') 156 | 157 | report_df = pd.DataFrame({ 158 | "cum_bench": report["bench"].cumsum(), 159 | "cum_return_w_cost": (report["return"] - report["cost"]).cumsum(), 160 | "cum_ex_return_w_cost": (report["return"] - report["bench"] - report["cost"]).cumsum(), 161 | }) 162 | return report_df 163 | 164 | def run_and_plot_results(self, signals: dict[str, pd.DataFrame]): 165 | """ 166 | Runs backtests for multiple signals and plots the cumulative return curves. 167 | 168 | Args: 169 | signals (dict[str, pd.DataFrame]): A dictionary where keys are signal names 170 | and values are prediction DataFrames. 
171 | """ 172 | return_df, ex_return_df, bench_df = pd.DataFrame(), pd.DataFrame(), pd.DataFrame() 173 | 174 | for signal_name, pred_df in signals.items(): 175 | print(f"\nBacktesting signal: {signal_name}...") 176 | pred_series = pred_df.stack() 177 | pred_series.index.names = ['datetime', 'instrument'] 178 | pred_series = pred_series.swaplevel().sort_index() 179 | report_df = self.run_single_backtest(pred_series) 180 | 181 | return_df[signal_name] = report_df['cum_return_w_cost'] 182 | ex_return_df[signal_name] = report_df['cum_ex_return_w_cost'] 183 | if 'return' not in bench_df: 184 | bench_df['return'] = report_df['cum_bench'] 185 | 186 | # Plotting results 187 | fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True) 188 | return_df.plot(ax=axes[0], title='Cumulative Return with Cost', grid=True) 189 | axes[0].plot(bench_df['return'], label=self.config.instrument.upper(), color='black', linestyle='--') 190 | axes[0].legend() 191 | axes[0].set_ylabel("Cumulative Return") 192 | 193 | ex_return_df.plot(ax=axes[1], title='Cumulative Excess Return with Cost', grid=True) 194 | axes[1].legend() 195 | axes[1].set_xlabel("Date") 196 | axes[1].set_ylabel("Cumulative Excess Return") 197 | 198 | plt.tight_layout() 199 | plt.savefig("../figures/backtest_result_example.png", dpi=200) 200 | plt.show() 201 | 202 | 203 | # ================================================================================= 204 | # 3. Inference Logic 205 | # ================================================================================= 206 | 207 | def load_models(config: dict) -> tuple[KronosTokenizer, Kronos]: 208 | """Loads the fine-tuned tokenizer and predictor model.""" 209 | device = torch.device(config['device']) 210 | print(f"Loading models onto device: {device}...") 211 | tokenizer = KronosTokenizer.from_pretrained(config['tokenizer_path']).to(device).eval() 212 | model = Kronos.from_pretrained(config['model_path']).to(device).eval() 213 | return tokenizer, model 214 | 215 | 216 | def collate_fn_for_inference(batch): 217 | """ 218 | Custom collate function to handle batches containing Tensors, strings, and Timestamps. 219 | 220 | Args: 221 | batch (list): A list of samples, where each sample is the tuple returned by 222 | QlibTestDataset.__getitem__. 223 | 224 | Returns: 225 | A single tuple containing the batched data. 226 | """ 227 | # Unzip the list of samples into separate lists for each data type 228 | x, x_stamp, y_stamp, symbols, timestamps = zip(*batch) 229 | 230 | # Stack the tensors to create a batch 231 | x_batch = torch.stack(x, dim=0) 232 | x_stamp_batch = torch.stack(x_stamp, dim=0) 233 | y_stamp_batch = torch.stack(y_stamp, dim=0) 234 | 235 | # Return the strings and timestamps as lists 236 | return x_batch, x_stamp_batch, y_stamp_batch, list(symbols), list(timestamps) 237 | 238 | 239 | def generate_predictions(config: dict, test_data: dict) -> dict[str, pd.DataFrame]: 240 | """ 241 | Runs inference on the test dataset to generate prediction signals. 242 | 243 | Args: 244 | config (dict): A dictionary containing inference parameters. 245 | test_data (dict): The raw test data loaded from a pickle file. 246 | 247 | Returns: 248 | A dictionary where keys are signal types (e.g., 'mean', 'last') and 249 | values are DataFrames of predictions (datetime index, symbol columns). 
250 | """ 251 | tokenizer, model = load_models(config) 252 | device = torch.device(config['device']) 253 | 254 | # Use the Dataset and DataLoader for efficient batching and processing 255 | dataset = QlibTestDataset(data=test_data, config=Config()) 256 | loader = DataLoader( 257 | dataset, 258 | batch_size=config['batch_size'] // config['sample_count'], 259 | shuffle=False, 260 | num_workers=os.cpu_count() // 2, 261 | collate_fn=collate_fn_for_inference 262 | ) 263 | 264 | results = defaultdict(list) 265 | with torch.no_grad(): 266 | for x, x_stamp, y_stamp, symbols, timestamps in tqdm(loader, desc="Inference"): 267 | preds = auto_regressive_inference( 268 | tokenizer, model, x.to(device), x_stamp.to(device), y_stamp.to(device), 269 | max_context=config['max_context'], pred_len=config['pred_len'], clip=config['clip'], 270 | T=config['T'], top_k=config['top_k'], top_p=config['top_p'], sample_count=config['sample_count'] 271 | ) 272 | # You can try commenting on this line to keep the history data 273 | preds = preds[:, -config['pred_len']:, :] 274 | 275 | # The 'close' price is at index 3 in `feature_list` 276 | last_day_close = x[:, -1, 3].numpy() 277 | signals = { 278 | 'last': preds[:, -1, 3] - last_day_close, 279 | 'mean': np.mean(preds[:, :, 3], axis=1) - last_day_close, 280 | 'max': np.max(preds[:, :, 3], axis=1) - last_day_close, 281 | 'min': np.min(preds[:, :, 3], axis=1) - last_day_close, 282 | } 283 | 284 | for i in range(len(symbols)): 285 | for sig_type, sig_values in signals.items(): 286 | results[sig_type].append((timestamps[i], symbols[i], sig_values[i])) 287 | 288 | print("Post-processing predictions into DataFrames...") 289 | prediction_dfs = {} 290 | for sig_type, records in results.items(): 291 | df = pd.DataFrame(records, columns=['datetime', 'instrument', 'score']) 292 | pivot_df = df.pivot_table(index='datetime', columns='instrument', values='score') 293 | prediction_dfs[sig_type] = pivot_df.sort_index() 294 | 295 | return prediction_dfs 296 | 297 | 298 | # ================================================================================= 299 | # 4. Main Execution 300 | # ================================================================================= 301 | 302 | def main(): 303 | """Main function to set up config, run inference, and execute backtesting.""" 304 | parser = argparse.ArgumentParser(description="Run Kronos Inference and Backtesting") 305 | parser.add_argument("--device", type=str, default="cuda:1", help="Device for inference (e.g., 'cuda:0', 'cpu')") 306 | args = parser.parse_args() 307 | 308 | # --- 1. 
Configuration Setup --- 309 | base_config = Config() 310 | 311 | # Create a dedicated dictionary for this run's configuration 312 | run_config = { 313 | 'device': args.device, 314 | 'data_path': base_config.dataset_path, 315 | 'result_save_path': base_config.backtest_result_path, 316 | 'result_name': base_config.backtest_save_folder_name, 317 | 'tokenizer_path': base_config.finetuned_tokenizer_path, 318 | 'model_path': base_config.finetuned_predictor_path, 319 | 'max_context': base_config.max_context, 320 | 'pred_len': base_config.predict_window, 321 | 'clip': base_config.clip, 322 | 'T': base_config.inference_T, 323 | 'top_k': base_config.inference_top_k, 324 | 'top_p': base_config.inference_top_p, 325 | 'sample_count': base_config.inference_sample_count, 326 | 'batch_size': base_config.backtest_batch_size, 327 | } 328 | 329 | print("--- Running with Configuration ---") 330 | for key, val in run_config.items(): 331 | print(f"{key:>20}: {val}") 332 | print("-" * 35) 333 | 334 | # --- 2. Load Data --- 335 | test_data_path = os.path.join(run_config['data_path'], "test_data.pkl") 336 | print(f"Loading test data from {test_data_path}...") 337 | with open(test_data_path, 'rb') as f: 338 | test_data = pickle.load(f) 339 | print(test_data) 340 | # --- 3. Generate Predictions --- 341 | model_preds = generate_predictions(run_config, test_data) 342 | 343 | # --- 4. Save Predictions --- 344 | save_dir = os.path.join(run_config['result_save_path'], run_config['result_name']) 345 | os.makedirs(save_dir, exist_ok=True) 346 | predictions_file = os.path.join(save_dir, "predictions.pkl") 347 | print(f"Saving prediction signals to {predictions_file}...") 348 | with open(predictions_file, 'wb') as f: 349 | pickle.dump(model_preds, f) 350 | 351 | # --- 5. Run Backtesting --- 352 | with open(predictions_file, 'rb') as f: 353 | model_preds = pickle.load(f) 354 | 355 | backtester = QlibBacktest(base_config) 356 | backtester.run_and_plot_results(model_preds) 357 | 358 | 359 | if __name__ == '__main__': 360 | main() 361 | 362 | 363 | -------------------------------------------------------------------------------- /finetune_csv/train_sequential.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import argparse 5 | import torch 6 | import torch.nn as nn 7 | from torch.utils.data import DataLoader 8 | import torch.distributed as dist 9 | 10 | sys.path.append('../') 11 | from model import Kronos, KronosTokenizer, KronosPredictor 12 | 13 | from config_loader import CustomFinetuneConfig 14 | from finetune_tokenizer import train_tokenizer, set_seed, setup_logging as setup_tokenizer_logging 15 | from finetune_base_model import train_model, create_dataloaders, setup_logging as setup_basemodel_logging 16 | 17 | 18 | class SequentialTrainer: 19 | 20 | def __init__(self, config_path: str = None): 21 | self.config = CustomFinetuneConfig(config_path) 22 | self.rank = int(os.environ.get("RANK", "0")) 23 | self.world_size = int(os.environ.get("WORLD_SIZE", "1")) 24 | self.local_rank = int(os.environ.get("LOCAL_RANK", str(self.config.device_id if hasattr(self.config, 'device_id') else 0))) 25 | self.device = self._setup_device() 26 | 27 | self.config.print_config_summary() 28 | 29 | def _setup_device(self): 30 | if self.config.use_cuda and torch.cuda.is_available(): 31 | torch.cuda.set_device(self.local_rank) 32 | device = torch.device(f"cuda:{self.local_rank}") 33 | else: 34 | device = torch.device("cpu") 35 | 36 | if self.rank == 0: 37 | 
print(f"Using device: {device} (rank={self.rank}, world_size={self.world_size}, local_rank={self.local_rank})") 38 | return device 39 | 40 | def _setup_distributed(self): 41 | if self.world_size > 1 and torch.cuda.is_available(): 42 | backend = os.environ.get("DIST_BACKEND", "nccl").lower() 43 | if not dist.is_initialized(): 44 | dist.init_process_group(backend=backend) 45 | if self.rank == 0: 46 | print(f"Distributed training initialized: backend={backend}, world_size={self.world_size}") 47 | else: 48 | if self.rank == 0: 49 | print("Distributed training not enabled, using single GPU/CPU training") 50 | 51 | def _check_existing_models(self): 52 | tokenizer_exists = os.path.exists(self.config.tokenizer_best_model_path) 53 | basemodel_exists = os.path.exists(self.config.basemodel_best_model_path) 54 | 55 | print(f"Tokenizer model exists: {tokenizer_exists}") 56 | print(f"Basemodel model exists: {basemodel_exists}") 57 | 58 | return tokenizer_exists, basemodel_exists 59 | 60 | def _create_directories(self): 61 | os.makedirs(self.config.tokenizer_save_path, exist_ok=True) 62 | os.makedirs(self.config.basemodel_save_path, exist_ok=True) 63 | print(f"Created directory: {self.config.tokenizer_save_path}") 64 | print(f"Created directory: {self.config.basemodel_save_path}") 65 | 66 | def train_tokenizer_phase(self): 67 | print("\n" + "="*60) 68 | print("Starting Tokenizer Fine-tuning Phase") 69 | print("="*60) 70 | 71 | tokenizer_exists, _ = self._check_existing_models() 72 | if tokenizer_exists and self.config.skip_existing: 73 | print("Tokenizer model already exists, skipping training") 74 | return True 75 | 76 | log_dir = os.path.join(self.config.base_save_path, "logs") 77 | logger = setup_tokenizer_logging(self.config.exp_name, log_dir, self.rank) 78 | 79 | set_seed(self.config.seed) 80 | 81 | if getattr(self.config, 'pre_trained_tokenizer', True): 82 | logger.info("Loading pretrained tokenizer...") 83 | if self.rank == 0: 84 | print("Loading pretrained tokenizer...") 85 | tokenizer = KronosTokenizer.from_pretrained(self.config.pretrained_tokenizer_path) 86 | else: 87 | if self.rank == 0: 88 | print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture") 89 | import json 90 | cfg_path = os.path.join(self.config.pretrained_tokenizer_path, 'config.json') 91 | with open(cfg_path, 'r') as f: 92 | arch = json.load(f) 93 | tokenizer = KronosTokenizer( 94 | d_in=arch.get('d_in', 6), 95 | d_model=arch.get('d_model', 256), 96 | n_heads=arch.get('n_heads', 4), 97 | ff_dim=arch.get('ff_dim', 512), 98 | n_enc_layers=arch.get('n_enc_layers', 4), 99 | n_dec_layers=arch.get('n_dec_layers', 4), 100 | ffn_dropout_p=arch.get('ffn_dropout_p', 0.0), 101 | attn_dropout_p=arch.get('attn_dropout_p', 0.0), 102 | resid_dropout_p=arch.get('resid_dropout_p', 0.0), 103 | s1_bits=arch.get('s1_bits', 10), 104 | s2_bits=arch.get('s2_bits', 10), 105 | beta=arch.get('beta', 0.05), 106 | gamma0=arch.get('gamma0', 1.0), 107 | gamma=arch.get('gamma', 1.1), 108 | zeta=arch.get('zeta', 0.05), 109 | group_size=arch.get('group_size', 4) 110 | ) 111 | tokenizer = tokenizer.to(self.device) 112 | 113 | model_size = sum(p.numel() for p in tokenizer.parameters()) 114 | logger.info(f"Tokenizer parameters: {model_size:,}") 115 | if self.rank == 0: 116 | print(f"Tokenizer parameters: {model_size:,}") 117 | 118 | logger.info("=== Training Configuration ===") 119 | logger.info(f"Data path: {self.config.data_path}") 120 | logger.info(f"Lookback window: {self.config.lookback_window}") 121 | logger.info(f"Predict window: 
{self.config.predict_window}") 122 | logger.info(f"Batch size: {self.config.batch_size}") 123 | logger.info(f"Learning rate: {self.config.tokenizer_learning_rate}") 124 | logger.info(f"Training epochs: {self.config.tokenizer_epochs}") 125 | logger.info(f"Device: {self.device}") 126 | logger.info(f"Distributed training: False") 127 | 128 | logger.info("Starting tokenizer fine-tuning training...") 129 | if self.rank == 0: 130 | print("Starting tokenizer fine-tuning training...") 131 | start_time = time.time() 132 | best_val_loss = train_tokenizer( 133 | tokenizer, 134 | self.device, 135 | self.config, 136 | self.config.tokenizer_save_path, 137 | logger, 138 | ) 139 | training_time = time.time() - start_time 140 | 141 | final_msg = f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\nTraining time: {training_time/60:.2f} minutes\nModel saved to: {self.config.tokenizer_save_path}" 142 | logger.info(final_msg) 143 | if self.rank == 0: 144 | print(f"\n{final_msg}") 145 | 146 | return True 147 | 148 | def train_basemodel_phase(self): 149 | print("\n" + "="*60) 150 | print("Starting Basemodel Fine-tuning Phase") 151 | print("="*60) 152 | 153 | if getattr(self.config, 'pre_trained_tokenizer', True): 154 | if not os.path.exists(self.config.finetuned_tokenizer_path): 155 | raise FileNotFoundError(f"Fine-tuned tokenizer does not exist: {self.config.finetuned_tokenizer_path}") 156 | 157 | _, basemodel_exists = self._check_existing_models() 158 | if basemodel_exists and self.config.skip_existing: 159 | print("Basemodel model already exists, skipping training") 160 | return True 161 | 162 | log_dir = os.path.join(self.config.base_save_path, "logs") 163 | logger = setup_basemodel_logging(self.config.exp_name, log_dir, self.rank) 164 | 165 | set_seed(self.config.seed) 166 | 167 | if getattr(self.config, 'pre_trained_tokenizer', True): 168 | logger.info("Loading fine-tuned tokenizer...") 169 | if self.rank == 0: 170 | print("Loading fine-tuned tokenizer...") 171 | tokenizer = KronosTokenizer.from_pretrained(self.config.finetuned_tokenizer_path) 172 | else: 173 | if self.rank == 0: 174 | print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for Predictor training") 175 | import json 176 | cfg_path = os.path.join(self.config.pretrained_tokenizer_path, 'config.json') 177 | with open(cfg_path, 'r') as f: 178 | arch = json.load(f) 179 | tokenizer = KronosTokenizer( 180 | d_in=arch.get('d_in', 6), 181 | d_model=arch.get('d_model', 256), 182 | n_heads=arch.get('n_heads', 4), 183 | ff_dim=arch.get('ff_dim', 512), 184 | n_enc_layers=arch.get('n_enc_layers', 4), 185 | n_dec_layers=arch.get('n_dec_layers', 4), 186 | ffn_dropout_p=arch.get('ffn_dropout_p', 0.0), 187 | attn_dropout_p=arch.get('attn_dropout_p', 0.0), 188 | resid_dropout_p=arch.get('resid_dropout_p', 0.0), 189 | s1_bits=arch.get('s1_bits', 10), 190 | s2_bits=arch.get('s2_bits', 10), 191 | beta=arch.get('beta', 0.05), 192 | gamma0=arch.get('gamma0', 1.0), 193 | gamma=arch.get('gamma', 1.1), 194 | zeta=arch.get('zeta', 0.05), 195 | group_size=arch.get('group_size', 4) 196 | ) 197 | tokenizer = tokenizer.to(self.device) 198 | 199 | if getattr(self.config, 'pre_trained_predictor', True): 200 | logger.info("Loading pretrained predictor...") 201 | if self.rank == 0: 202 | print("Loading pretrained predictor...") 203 | model = Kronos.from_pretrained(self.config.pretrained_predictor_path) 204 | else: 205 | if self.rank == 0: 206 | print("pre_trained_predictor=False, randomly initializing Predictor architecture") 
207 | import json 208 | cfg_path = os.path.join(self.config.pretrained_predictor_path, 'config.json') 209 | with open(cfg_path, 'r') as f: 210 | arch = json.load(f) 211 | print("model_config: ", arch) 212 | model = Kronos( 213 | s1_bits=arch.get('s1_bits', 10), 214 | s2_bits=arch.get('s2_bits', 10), 215 | n_layers=arch.get('n_layers', 12), 216 | d_model=arch.get('d_model', 832), 217 | n_heads=arch.get('n_heads', 16), 218 | ff_dim=arch.get('ff_dim', 2048), 219 | ffn_dropout_p=arch.get('ffn_dropout_p', 0.2), 220 | attn_dropout_p=arch.get('attn_dropout_p', 0.0), 221 | resid_dropout_p=arch.get('resid_dropout_p', 0.2), 222 | token_dropout_p=arch.get('token_dropout_p', 0.0), 223 | learn_te=arch.get('learn_te', True) 224 | ) 225 | model = model.to(self.device) 226 | 227 | model_size = sum(p.numel() for p in model.parameters()) 228 | logger.info(f"Model parameters: {model_size:,}") 229 | if self.rank == 0: 230 | print(f"Model parameters: {model_size:,}") 231 | 232 | logger.info("=== Training Configuration ===") 233 | logger.info(f"Data path: {self.config.data_path}") 234 | logger.info(f"Lookback window: {self.config.lookback_window}") 235 | logger.info(f"Predict window: {self.config.predict_window}") 236 | logger.info(f"Batch size: {self.config.batch_size}") 237 | logger.info(f"Learning rate: {self.config.predictor_learning_rate}") 238 | logger.info(f"Training epochs: {self.config.basemodel_epochs}") 239 | logger.info(f"Device: {self.device}") 240 | logger.info(f"Tokenizer path: {self.config.finetuned_tokenizer_path}") 241 | logger.info(f"Pretrained model path: {self.config.pretrained_predictor_path}") 242 | 243 | logger.info("Starting fine-tuning training...") 244 | if self.rank == 0: 245 | print("Starting fine-tuning training...") 246 | start_time = time.time() 247 | best_val_loss = train_model( 248 | model, 249 | tokenizer, 250 | self.device, 251 | self.config, 252 | self.config.basemodel_save_path, 253 | logger, 254 | ) 255 | training_time = time.time() - start_time 256 | 257 | final_msg = f"Basemodel training completed! 
Best validation loss: {best_val_loss:.4f}\nTraining time: {training_time/60:.2f} minutes\nModel saved to: {self.config.basemodel_save_path}" 258 | logger.info(final_msg) 259 | if self.rank == 0: 260 | print(f"\n{final_msg}") 261 | 262 | return True 263 | 264 | def run_training(self): 265 | if self.rank == 0: 266 | print("Starting Kronos model sequential fine-tuning training") 267 | print(f"Experiment name: {self.config.experiment_name}") 268 | print(f"Experiment description: {self.config.experiment_description}") 269 | 270 | self._setup_distributed() 271 | 272 | self._create_directories() 273 | 274 | tokenizer_exists, basemodel_exists = self._check_existing_models() 275 | 276 | total_start_time = time.time() 277 | 278 | try: 279 | if self.config.train_tokenizer: 280 | success = self.train_tokenizer_phase() 281 | if not success: 282 | print("Tokenizer training failed, terminating training") 283 | return False 284 | else: 285 | print("Skipping Tokenizer training phase") 286 | 287 | if self.config.train_basemodel: 288 | success = self.train_basemodel_phase() 289 | if not success: 290 | print("Basemodel training failed, terminating training") 291 | return False 292 | else: 293 | print("Skipping Basemodel training phase") 294 | 295 | total_time = time.time() - total_start_time 296 | 297 | if self.rank == 0: 298 | print("\n" + "="*60) 299 | print("Training completed!") 300 | print("="*60) 301 | print(f"Total training time: {total_time/60:.2f} minutes") 302 | print(f"Tokenizer model: {self.config.tokenizer_best_model_path}") 303 | print(f"Basemodel model: {self.config.basemodel_best_model_path}") 304 | print("="*60) 305 | 306 | return True 307 | 308 | except Exception as e: 309 | if self.rank == 0: 310 | print(f"Error occurred during training: {str(e)}") 311 | import traceback 312 | traceback.print_exc() 313 | return False 314 | 315 | finally: 316 | pass 317 | 318 | 319 | def main(): 320 | parser = argparse.ArgumentParser(description='Kronos Model Sequential Fine-tuning Training') 321 | parser.add_argument('--config', type=str, default='config.yaml', 322 | help='Configuration file path (default: config.yaml)') 323 | parser.add_argument('--skip-tokenizer', action='store_true', 324 | help='Skip tokenizer training phase') 325 | parser.add_argument('--skip-basemodel', action='store_true', 326 | help='Skip basemodel training phase') 327 | parser.add_argument('--skip-existing', action='store_true', 328 | help='Skip training for existing models') 329 | 330 | args = parser.parse_args() 331 | 332 | trainer = SequentialTrainer(args.config) 333 | 334 | if args.skip_tokenizer: 335 | trainer.config.train_tokenizer = False 336 | if args.skip_basemodel: 337 | trainer.config.train_basemodel = False 338 | if args.skip_existing: 339 | trainer.config.skip_existing = True 340 | 341 | success = trainer.run_training() 342 | 343 | if success: 344 | print("Training completed successfully!") 345 | if dist.is_available() and dist.is_initialized(): 346 | dist.barrier() 347 | dist.destroy_process_group() 348 | sys.exit(0) 349 | else: 350 | print("Training failed!") 351 | if dist.is_available() and dist.is_initialized(): 352 | try: 353 | dist.barrier() 354 | dist.destroy_process_group() 355 | except Exception: 356 | pass 357 | sys.exit(1) 358 | 359 | 360 | if __name__ == "__main__": 361 | main() 362 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

Kronos: A Foundation Model for the Language of Financial Markets

3 |
4 | 5 | 6 |
7 | 8 | 9 | 10 | Hugging Face 11 | 12 | Live Demo 13 | 14 | Last Commit 15 | 16 | 17 | GitHub Stars 18 | 19 | 20 | GitHub Forks 21 | 22 | 23 | License 24 | 25 | 26 |
27 | 28 |
29 | 30 | Deutsch | 31 | Español | 32 | Français | 33 | 日本語 | 34 | 한국어 | 35 | Português | 36 | Русский | 37 | 中文 38 |
39 | 40 |

41 | 42 | 43 | 44 |

45 | 46 | > Kronos is the **first open-source foundation model** for financial candlesticks (K-lines), 47 | > trained on data from over **45 global exchanges**. 48 | 49 | 50 | 51 | 52 | ## 📰 News 53 | * 🚩 **[2025.08.17]** We have released the scripts for fine-tuning! Check them out to adapt Kronos to your own tasks. 54 | * 🚩 **[2025.08.02]** Our paper is now available on [arXiv](https://arxiv.org/abs/2508.02739)! 55 | 56 |

57 | 58 | ## 📜 Introduction 59 | 60 | **Kronos** is a family of decoder-only foundation models, pre-trained specifically for the "language" of financial markets—K-line sequences. Unlike general-purpose TSFMs, Kronos is designed to handle the unique, high-noise characteristics of financial data. It leverages a novel two-stage framework: 61 | 1. A specialized tokenizer first quantizes continuous, multi-dimensional K-line data (OHLCV) into **hierarchical discrete tokens**. 62 | 2. A large, autoregressive Transformer is then pre-trained on these tokens, enabling it to serve as a unified model for diverse quantitative tasks. 63 | 64 |
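To make the two-stage design concrete, here is a minimal sketch of Stage 1, assuming a pre-normalized OHLCV window (random data stands in for real K-lines here) and the `encode(..., half=True)` call used by the fine-tuning scripts in this repository; treat it as an illustration of the data flow rather than a complete API reference:

```python
import torch
from model import KronosTokenizer

# Stage 1: the tokenizer turns a (z-scored) OHLCV window into two hierarchical token streams.
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
x = torch.randn(1, 400, 6)  # [batch, time, (open, high, low, close, volume, amount)], already normalized

with torch.no_grad():
    s1_tokens, s2_tokens = tokenizer.encode(x, half=True)  # coarse and fine token sequences

print(s1_tokens.shape, s2_tokens.shape)  # one coarse and one fine token per time step

# Stage 2: the autoregressive Kronos Transformer is pre-trained to predict the next token
# pair from this token history; for forecasting you normally use KronosPredictor, which
# wraps the whole loop (see "Making Forecasts" below).
```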

65 | 66 |

67 | 68 | ## ✨ Live Demo 69 | We have set up a live demo to visualize Kronos's forecasting results. The webpage showcases a forecast for the **BTC/USDT** trading pair over the next 24 hours. 70 | 71 | **👉 [Access the Live Demo Here](https://shiyu-coder.github.io/Kronos-demo/)** 72 | 73 | ## 📦 Model Zoo 74 | We release a family of pre-trained models with varying capacities to suit different computational and application needs. All models are readily accessible from the Hugging Face Hub. 75 | 76 | | Model | Tokenizer | Context length | Params | Open-source | 77 | |--------------|---------------------------------------------------------------------------------| -------------- | ------ |---------------------------------------------------------------------------| 78 | | Kronos-mini | [Kronos-Tokenizer-2k](https://huggingface.co/NeoQuasar/Kronos-Tokenizer-2k) | 2048 | 4.1M | ✅ [NeoQuasar/Kronos-mini](https://huggingface.co/NeoQuasar/Kronos-mini) | 79 | | Kronos-small | [Kronos-Tokenizer-base](https://huggingface.co/NeoQuasar/Kronos-Tokenizer-base) | 512 | 24.7M | ✅ [NeoQuasar/Kronos-small](https://huggingface.co/NeoQuasar/Kronos-small) | 80 | | Kronos-base | [Kronos-Tokenizer-base](https://huggingface.co/NeoQuasar/Kronos-Tokenizer-base) | 512 | 102.3M | ✅ [NeoQuasar/Kronos-base](https://huggingface.co/NeoQuasar/Kronos-base) | 81 | | Kronos-large | [Kronos-Tokenizer-base](https://huggingface.co/NeoQuasar/Kronos-Tokenizer-base) | 512 | 499.2M | ❌ | 82 | 83 | 84 | ## 🚀 Getting Started 85 | 86 | ### Installation 87 | 88 | 1. Install Python 3.10+, and then install the dependencies: 89 | 90 | ```shell 91 | pip install -r requirements.txt 92 | ``` 93 | 94 | ### 📈 Making Forecasts 95 | 96 | Forecasting with Kronos is straightforward using the `KronosPredictor` class. It handles data preprocessing, normalization, prediction, and inverse normalization, allowing you to get from raw data to forecasts in just a few lines of code. 97 | 98 | **Important Note**: The `max_context` for `Kronos-small` and `Kronos-base` is **512**. This is the maximum sequence length the model can process. For optimal performance, it is recommended that your input data length (i.e., `lookback`) does not exceed this limit. The `KronosPredictor` will automatically handle truncation for longer contexts. 99 | 100 | Here is a step-by-step guide to making your first forecast. 101 | 102 | #### 1. Load the Tokenizer and Model 103 | 104 | First, load a pre-trained Kronos model and its corresponding tokenizer from the Hugging Face Hub. 105 | 106 | ```python 107 | from model import Kronos, KronosTokenizer, KronosPredictor 108 | 109 | # Load from Hugging Face Hub 110 | tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base") 111 | model = Kronos.from_pretrained("NeoQuasar/Kronos-small") 112 | ``` 113 | 114 | #### 2. Instantiate the Predictor 115 | 116 | Create an instance of `KronosPredictor`, passing the model, tokenizer, and desired device. 117 | 118 | ```python 119 | # Initialize the predictor 120 | predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512) 121 | ``` 122 | 123 | #### 3. Prepare Input Data 124 | 125 | The `predict` method requires three main inputs: 126 | - `df`: A pandas DataFrame containing the historical K-line data. It must include columns `['open', 'high', 'low', 'close']`. `volume` and `amount` are optional. 127 | - `x_timestamp`: A pandas Series of timestamps corresponding to the historical data in `df`. 
128 | - `y_timestamp`: A pandas Series of timestamps for the future periods you want to predict. 129 | 130 | ```python 131 | import pandas as pd 132 | 133 | # Load your data 134 | df = pd.read_csv("./data/XSHG_5min_600977.csv") 135 | df['timestamps'] = pd.to_datetime(df['timestamps']) 136 | 137 | # Define context window and prediction length 138 | lookback = 400 139 | pred_len = 120 140 | 141 | # Prepare inputs for the predictor 142 | x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close', 'volume', 'amount']] 143 | x_timestamp = df.loc[:lookback-1, 'timestamps'] 144 | y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps'] 145 | ``` 146 | 147 | #### 4. Generate Forecasts 148 | 149 | Call the `predict` method to generate forecasts. You can control the sampling process with parameters like `T`, `top_p`, and `sample_count` for probabilistic forecasting. 150 | 151 | ```python 152 | # Generate predictions 153 | pred_df = predictor.predict( 154 | df=x_df, 155 | x_timestamp=x_timestamp, 156 | y_timestamp=y_timestamp, 157 | pred_len=pred_len, 158 | T=1.0, # Temperature for sampling 159 | top_p=0.9, # Nucleus sampling probability 160 | sample_count=1 # Number of forecast paths to generate and average 161 | ) 162 | 163 | print("Forecasted Data Head:") 164 | print(pred_df.head()) 165 | ``` 166 | 167 | The `predict` method returns a pandas DataFrame containing the forecasted values for `open`, `high`, `low`, `close`, `volume`, and `amount`, indexed by the `y_timestamp` you provided. 168 | 169 | For efficient processing of multiple time series, Kronos provides a `predict_batch` method that enables parallel prediction on multiple datasets simultaneously. This is particularly useful when you need to forecast multiple assets or time periods at once. 170 | 171 | ```python 172 | # Prepare multiple datasets for batch prediction 173 | df_list = [df1, df2, df3] # List of DataFrames 174 | x_timestamp_list = [x_ts1, x_ts2, x_ts3] # List of historical timestamps 175 | y_timestamp_list = [y_ts1, y_ts2, y_ts3] # List of future timestamps 176 | 177 | # Generate batch predictions 178 | pred_df_list = predictor.predict_batch( 179 | df_list=df_list, 180 | x_timestamp_list=x_timestamp_list, 181 | y_timestamp_list=y_timestamp_list, 182 | pred_len=pred_len, 183 | T=1.0, 184 | top_p=0.9, 185 | sample_count=1, 186 | verbose=True 187 | ) 188 | 189 | # pred_df_list contains prediction results in the same order as input 190 | for i, pred_df in enumerate(pred_df_list): 191 | print(f"Predictions for series {i}:") 192 | print(pred_df.head()) 193 | ``` 194 | 195 | **Important Requirements for Batch Prediction:** 196 | - All series must have the same historical length (lookback window) 197 | - All series must have the same prediction length (`pred_len`) 198 | - Each DataFrame must contain the required columns: `['open', 'high', 'low', 'close']` 199 | - `volume` and `amount` columns are optional and will be filled with zeros if missing 200 | 201 | The `predict_batch` method leverages GPU parallelism for efficient processing and automatically handles normalization and denormalization for each series independently. 202 | 203 | #### 5. Example and Visualization 204 | 205 | For a complete, runnable script that includes data loading, prediction, and plotting, please see [`examples/prediction_example.py`](examples/prediction_example.py). 206 | 207 | Running this script will generate a plot comparing the ground truth data against the model's forecast, similar to the one shown below: 208 | 209 |

210 | Forecast Example 211 |
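If you want to build a similar comparison plot yourself rather than running the example script, a minimal matplotlib sketch could look like the following. It assumes the `df`, `pred_df`, `lookback`, and `pred_len` variables from the snippets above, and that your CSV actually contains the ground-truth rows for the forecast horizon:

```python
import matplotlib.pyplot as plt

# Ground truth over the forecast horizon, taken from the same CSV
truth = df.loc[lookback:lookback + pred_len - 1, ['timestamps', 'close']]

plt.figure(figsize=(10, 4))
plt.plot(df.loc[:lookback - 1, 'timestamps'], df.loc[:lookback - 1, 'close'], label='history')
plt.plot(truth['timestamps'], truth['close'], label='ground truth')
plt.plot(pred_df.index, pred_df['close'], label='forecast', linestyle='--')
plt.legend()
plt.title('Kronos close-price forecast vs. ground truth')
plt.tight_layout()
plt.show()
```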

212 | 213 | Additionally, we provide a script that makes predictions without Volume and Amount data, which can be found in [`examples/prediction_wo_vol_example.py`](examples/prediction_wo_vol_example.py). 214 | 215 | 216 | ## 🔧 Finetuning on Your Own Data (A-Share Market Example) 217 | 218 | We provide a complete pipeline for finetuning Kronos on your own datasets. As an example, we demonstrate how to use [Qlib](https://github.com/microsoft/qlib) to prepare data from the Chinese A-share market and conduct a simple backtest. 219 | 220 | > **Disclaimer:** This pipeline is intended as a demonstration to illustrate the finetuning process. It is a simplified example and not a production-ready quantitative trading system. A robust quantitative strategy requires more sophisticated techniques, such as portfolio optimization and risk factor neutralization, to achieve stable alpha. 221 | 222 | The finetuning process is divided into four main steps: 223 | 224 | 1. **Configuration**: Set up paths and hyperparameters. 225 | 2. **Data Preparation**: Process and split your data using Qlib. 226 | 3. **Model Finetuning**: Finetune the Tokenizer and the Predictor models. 227 | 4. **Backtesting**: Evaluate the finetuned model's performance. 228 | 229 | ### Prerequisites 230 | 231 | 1. First, ensure you have all dependencies from `requirements.txt` installed. 232 | 2. This pipeline relies on `qlib`. Please install it: 233 | ```shell 234 | pip install pyqlib 235 | ``` 236 | 3. You will need to prepare your Qlib data. Follow the [official Qlib guide](https://github.com/microsoft/qlib) to download and set up your data locally. The example scripts assume you are using daily frequency data. 237 | 238 | ### Step 1: Configure Your Experiment 239 | 240 | All settings for data, training, and model paths are centralized in `finetune/config.py`. Before running any scripts, please **modify the following paths** according to your environment: 241 | 242 | * `qlib_data_path`: Path to your local Qlib data directory. 243 | * `dataset_path`: Directory where the processed train/validation/test pickle files will be saved. 244 | * `save_path`: Base directory for saving model checkpoints. 245 | * `backtest_result_path`: Directory for saving backtesting results. 246 | * `pretrained_tokenizer_path` and `pretrained_predictor_path`: Paths to the pre-trained models you want to start from (can be local paths or Hugging Face model names). 247 | 248 | You can also adjust other parameters like `instrument`, `train_time_range`, `epochs`, and `batch_size` to fit your specific task. If you don't use [Comet.ml](https://www.comet.com/), set `use_comet = False`. 249 | 250 | ### Step 2: Prepare the Dataset 251 | 252 | Run the data preprocessing script. This script will load raw market data from your Qlib directory, process it, split it into training, validation, and test sets, and save them as pickle files. 253 | 254 | ```shell 255 | python finetune/qlib_data_preprocess.py 256 | ``` 257 | 258 | After running, you will find `train_data.pkl`, `val_data.pkl`, and `test_data.pkl` in the directory specified by `dataset_path` in your config. 259 | 260 | ### Step 3: Run the Finetuning 261 | 262 | The finetuning process consists of two stages: finetuning the tokenizer and then the predictor. Both training scripts are designed for multi-GPU training using `torchrun`. 263 | 264 | #### 3.1 Finetune the Tokenizer 265 | 266 | This step adjusts the tokenizer to the data distribution of your specific domain. 
267 | 268 | ```shell 269 | # Replace NUM_GPUS with the number of GPUs you want to use (e.g., 2) 270 | torchrun --standalone --nproc_per_node=NUM_GPUS finetune/train_tokenizer.py 271 | ``` 272 | 273 | The best tokenizer checkpoint will be saved to the path configured in `config.py` (derived from `save_path` and `tokenizer_save_folder_name`). 274 | 275 | #### 3.2 Finetune the Predictor 276 | 277 | This step finetunes the main Kronos model for the forecasting task. 278 | 279 | ```shell 280 | # Replace NUM_GPUS with the number of GPUs you want to use (e.g., 2) 281 | torchrun --standalone --nproc_per_node=NUM_GPUS finetune/train_predictor.py 282 | ``` 283 | 284 | The best predictor checkpoint will be saved to the path configured in `config.py`. 285 | 286 | ### Step 4: Evaluate with Backtesting 287 | 288 | Finally, run the backtesting script to evaluate your finetuned model. This script loads the models, performs inference on the test set, generates prediction signals (e.g., forecasted price change), and runs a simple top-K strategy backtest. 289 | 290 | ```shell 291 | # Specify the GPU for inference 292 | python finetune/qlib_test.py --device cuda:0 293 | ``` 294 | 295 | The script will output a detailed performance analysis in your console and generate a plot showing the cumulative return curves of your strategy against the benchmark, similar to the one below: 296 | 297 |

298 | Backtest Example 299 |
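For intuition about what the simple top-K strategy consumes: the inference script stores each signal as a `datetime × instrument` score matrix (one DataFrame per signal type in `predictions.pkl`). The sketch below shows how top-K instruments per day could be selected from such a matrix; it only illustrates the data layout and is not the actual logic inside `QlibBacktest`:

```python
import pandas as pd

def daily_top_k(scores: pd.DataFrame, k: int = 50) -> pd.DataFrame:
    """Return a 0/1 mask marking the k highest-scoring instruments on each date.

    `scores` is indexed by datetime with one column per instrument, matching the
    pivoted prediction signals saved by the inference script.
    """
    ranks = scores.rank(axis=1, ascending=False, method='first')
    return (ranks <= k).astype(int)

# Example usage with the saved signals (a dict of signal_name -> DataFrame):
# predictions = pd.read_pickle('predictions.pkl')
# holdings = daily_top_k(predictions['mean'], k=50)
```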

300 | 301 | ### 💡 From Demo to Production: Important Considerations 302 | 303 | * **Raw Signals vs. Pure Alpha**: The signals generated by the model in this demo are raw predictions. In a real-world quantitative workflow, these signals would typically be fed into a portfolio optimization model. This model would apply constraints to neutralize exposure to common risk factors (e.g., market beta, style factors like size and value), thereby isolating the **"pure alpha"** and improving the strategy's robustness. 304 | * **Data Handling**: The provided `QlibDataset` is an example. For different data sources or formats, you will need to adapt the data loading and preprocessing logic. 305 | * **Strategy and Backtesting Complexity**: The simple top-K strategy used here is a basic starting point. Production-level strategies often incorporate more complex logic for portfolio construction, dynamic position sizing, and risk management (e.g., stop-loss/take-profit rules). Furthermore, a high-fidelity backtest should meticulously model transaction costs, slippage, and market impact to provide a more accurate estimate of real-world performance. 306 | 307 | > **📝 AI-Generated Comments**: Please note that many of the code comments within the `finetune/` directory were generated by an AI assistant (Gemini 2.5 Pro) for explanatory purposes. While they aim to be helpful, they may contain inaccuracies. We recommend treating the code itself as the definitive source of logic. 308 | 309 | ## 📖 Citation 310 | 311 | If you use Kronos in your research, we would appreciate a citation to our [paper](https://arxiv.org/abs/2508.02739): 312 | 313 | ``` 314 | @misc{shi2025kronos, 315 | title={Kronos: A Foundation Model for the Language of Financial Markets}, 316 | author={Yu Shi and Zongliang Fu and Shuo Chen and Bohan Zhao and Wei Xu and Changshui Zhang and Jian Li}, 317 | year={2025}, 318 | eprint={2508.02739}, 319 | archivePrefix={arXiv}, 320 | primaryClass={q-fin.ST}, 321 | url={https://arxiv.org/abs/2508.02739}, 322 | } 323 | ``` 324 | 325 | ## 📜 License 326 | This project is licensed under the [MIT License](./LICENSE). 
327 | 328 | 329 | 330 | 331 | 332 | 333 | 334 | 335 | -------------------------------------------------------------------------------- /finetune_csv/finetune_base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import time 5 | import pickle 6 | import random 7 | import pandas as pd 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | from torch.utils.data import Dataset, DataLoader 12 | from torch.utils.data.distributed import DistributedSampler 13 | from time import gmtime, strftime 14 | import logging 15 | from logging.handlers import RotatingFileHandler 16 | import datetime 17 | import torch.distributed as dist 18 | from torch.nn.parallel import DistributedDataParallel as DDP 19 | 20 | sys.path.append('../') 21 | from model import Kronos, KronosTokenizer, KronosPredictor 22 | from config_loader import CustomFinetuneConfig 23 | 24 | 25 | class CustomKlineDataset(Dataset): 26 | 27 | def __init__(self, data_path, data_type='train', lookback_window=90, predict_window=10, 28 | clip=5.0, seed=100, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15): 29 | self.data_path = data_path 30 | self.data_type = data_type 31 | self.lookback_window = lookback_window 32 | self.predict_window = predict_window 33 | self.window = lookback_window + predict_window + 1 34 | self.clip = clip 35 | self.seed = seed 36 | self.train_ratio = train_ratio 37 | self.val_ratio = val_ratio 38 | self.test_ratio = test_ratio 39 | 40 | self.feature_list = ['open', 'high', 'low', 'close', 'volume', 'amount'] 41 | self.time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month'] 42 | 43 | self.py_rng = random.Random(seed) 44 | 45 | self._load_and_preprocess_data() 46 | self._split_data_by_time() 47 | 48 | self.n_samples = len(self.data) - self.window + 1 49 | 50 | print(f"[{data_type.upper()}] Data length: {len(self.data)}, Available samples: {self.n_samples}") 51 | 52 | def _load_and_preprocess_data(self): 53 | df = pd.read_csv(self.data_path) 54 | 55 | df['timestamps'] = pd.to_datetime(df['timestamps']) 56 | df = df.sort_values('timestamps').reset_index(drop=True) 57 | 58 | self.timestamps = df['timestamps'].copy() 59 | 60 | df['minute'] = df['timestamps'].dt.minute 61 | df['hour'] = df['timestamps'].dt.hour 62 | df['weekday'] = df['timestamps'].dt.weekday 63 | df['day'] = df['timestamps'].dt.day 64 | df['month'] = df['timestamps'].dt.month 65 | 66 | self.data = df[self.feature_list + self.time_feature_list].copy() 67 | 68 | if self.data.isnull().any().any(): 69 | print("Warning: Missing values found in data, performing forward fill") 70 | self.data = self.data.fillna(method='ffill') 71 | 72 | print(f"Original data time range: {self.timestamps.min()} to {self.timestamps.max()}") 73 | print(f"Original data total length: {len(df)} records") 74 | 75 | def _split_data_by_time(self): 76 | total_length = len(self.data) 77 | 78 | train_end = int(total_length * self.train_ratio) 79 | val_end = int(total_length * (self.train_ratio + self.val_ratio)) 80 | 81 | if self.data_type == 'train': 82 | self.data = self.data.iloc[:train_end].copy() 83 | self.timestamps = self.timestamps.iloc[:train_end].copy() 84 | print(f"[{self.data_type.upper()}] Training set: first {train_end} time points ({self.train_ratio})") 85 | print(f"[{self.data_type.upper()}] Training set time range: {self.timestamps.min()} to {self.timestamps.max()}") 86 | elif self.data_type == 'val': 87 | self.data = self.data.iloc[train_end:val_end].copy() 88 | 
self.timestamps = self.timestamps.iloc[train_end:val_end].copy() 89 | print(f"[{self.data_type.upper()}] Validation set: time points {train_end+1} to {val_end} ({self.val_ratio})") 90 | print(f"[{self.data_type.upper()}] Validation set time range: {self.timestamps.min()} to {self.timestamps.max()}") 91 | elif self.data_type == 'test': 92 | self.data = self.data.iloc[val_end:].copy() 93 | self.timestamps = self.timestamps.iloc[val_end:].copy() 94 | print(f"[{self.data_type.upper()}] Test set: after time point {val_end+1}") 95 | print(f"[{self.data_type.upper()}] Test set time range: {self.timestamps.min()} to {self.timestamps.max()}") 96 | 97 | print(f"[{self.data_type.upper()}] Data length after split: {len(self.data)} records") 98 | 99 | def set_epoch_seed(self, epoch): 100 | epoch_seed = self.seed + epoch 101 | self.py_rng.seed(epoch_seed) 102 | self.current_epoch = epoch 103 | 104 | def __len__(self): 105 | return self.n_samples 106 | 107 | def __getitem__(self, idx): 108 | max_start = len(self.data) - self.window 109 | if max_start <= 0: 110 | raise ValueError("Data length insufficient to create samples") 111 | 112 | if self.data_type == 'train': 113 | epoch = getattr(self, 'current_epoch', 0) 114 | start_idx = (idx * 9973 + (epoch + 1) * 104729) % (max_start + 1) 115 | else: 116 | start_idx = idx % (max_start + 1) 117 | 118 | end_idx = start_idx + self.window 119 | 120 | window_data = self.data.iloc[start_idx:end_idx] 121 | 122 | x = window_data[self.feature_list].values.astype(np.float32) 123 | x_stamp = window_data[self.time_feature_list].values.astype(np.float32) 124 | 125 | x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0) 126 | x = (x - x_mean) / (x_std + 1e-5) 127 | x = np.clip(x, -self.clip, self.clip) 128 | 129 | x_tensor = torch.from_numpy(x) 130 | x_stamp_tensor = torch.from_numpy(x_stamp) 131 | 132 | return x_tensor, x_stamp_tensor 133 | 134 | 135 | 136 | 137 | def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger: 138 | os.makedirs(log_dir, exist_ok=True) 139 | 140 | logger = logging.getLogger(f"basemodel_training_rank_{rank}") 141 | logger.setLevel(logging.INFO) 142 | 143 | if logger.handlers: 144 | return logger 145 | 146 | log_file = os.path.join(log_dir, f"basemodel_training_rank_{rank}.log") 147 | file_handler = RotatingFileHandler( 148 | log_file, 149 | maxBytes=10*1024*1024, 150 | backupCount=5, 151 | encoding='utf-8' 152 | ) 153 | file_handler.setLevel(logging.INFO) 154 | 155 | console_handler = None 156 | if rank == 0: 157 | console_handler = logging.StreamHandler() 158 | console_handler.setLevel(logging.INFO) 159 | 160 | formatter = logging.Formatter( 161 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s', 162 | datefmt='%Y-%m-%d %H:%M:%S' 163 | ) 164 | file_handler.setFormatter(formatter) 165 | if console_handler is not None: 166 | console_handler.setFormatter(formatter) 167 | 168 | logger.addHandler(file_handler) 169 | if console_handler is not None: 170 | logger.addHandler(console_handler) 171 | 172 | logger.info(f"=== Basemodel Training Started ===") 173 | logger.info(f"Experiment Name: {exp_name}") 174 | logger.info(f"Log Directory: {log_dir}") 175 | logger.info(f"Rank: {rank}") 176 | logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") 177 | 178 | return logger 179 | 180 | 181 | def create_dataloaders(config): 182 | if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0: 183 | print("Creating data loaders...") 184 | 185 | train_dataset = CustomKlineDataset( 186 | 
data_path=config.data_path, 187 | data_type='train', 188 | lookback_window=config.lookback_window, 189 | predict_window=config.predict_window, 190 | clip=config.clip, 191 | seed=config.seed, 192 | train_ratio=config.train_ratio, 193 | val_ratio=config.val_ratio, 194 | test_ratio=config.test_ratio 195 | ) 196 | 197 | val_dataset = CustomKlineDataset( 198 | data_path=config.data_path, 199 | data_type='val', 200 | lookback_window=config.lookback_window, 201 | predict_window=config.predict_window, 202 | clip=config.clip, 203 | seed=config.seed + 1, 204 | train_ratio=config.train_ratio, 205 | val_ratio=config.val_ratio, 206 | test_ratio=config.test_ratio 207 | ) 208 | 209 | use_ddp = dist.is_available() and dist.is_initialized() 210 | train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) if use_ddp else None 211 | val_sampler = DistributedSampler(val_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=False) if use_ddp else None 212 | 213 | train_loader = DataLoader( 214 | train_dataset, 215 | batch_size=config.batch_size, 216 | shuffle=(train_sampler is None), 217 | num_workers=config.num_workers, 218 | pin_memory=True, 219 | drop_last=True, 220 | sampler=train_sampler 221 | ) 222 | 223 | val_loader = DataLoader( 224 | val_dataset, 225 | batch_size=config.batch_size, 226 | shuffle=False, 227 | num_workers=config.num_workers, 228 | pin_memory=True, 229 | drop_last=False, 230 | sampler=val_sampler 231 | ) 232 | 233 | if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0: 234 | print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}") 235 | 236 | return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler 237 | 238 | 239 | def train_model(model, tokenizer, device, config, save_dir, logger): 240 | logger.info("Starting training...") 241 | use_ddp = dist.is_available() and dist.is_initialized() 242 | rank = dist.get_rank() if use_ddp else 0 243 | world_size = dist.get_world_size() if use_ddp else 1 244 | 245 | train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config) 246 | optimizer = torch.optim.AdamW( 247 | model.parameters(), 248 | lr=config.predictor_learning_rate, 249 | betas=(config.adam_beta1, config.adam_beta2), 250 | weight_decay=config.adam_weight_decay 251 | ) 252 | 253 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 254 | optimizer, 255 | max_lr=config.predictor_learning_rate, 256 | steps_per_epoch=len(train_loader), 257 | epochs=config.basemodel_epochs, 258 | pct_start=0.03, 259 | div_factor=10 260 | ) 261 | 262 | if use_ddp: 263 | local_rank = int(os.environ.get("LOCAL_RANK", "0")) 264 | model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False) 265 | 266 | best_val_loss = float('inf') 267 | batch_idx_global = 0 268 | 269 | for epoch in range(config.basemodel_epochs): 270 | epoch_start_time = time.time() 271 | model.train() 272 | 273 | train_dataset.set_epoch_seed(epoch * 10000) 274 | val_dataset.set_epoch_seed(0) 275 | if train_sampler is not None: 276 | train_sampler.set_epoch(epoch) 277 | 278 | epoch_train_loss = 0.0 279 | train_batches = 0 280 | 281 | for batch_idx, (batch_x, batch_x_stamp) in enumerate(train_loader): 282 | batch_x = batch_x.to(device, non_blocking=True) 283 | batch_x_stamp = batch_x_stamp.to(device, non_blocking=True) 284 | 285 | with torch.no_grad(): 286 | token_seq_0, 
token_seq_1 = tokenizer.encode(batch_x, half=True) 287 | 288 | token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]] 289 | token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]] 290 | 291 | logits = (model.module if use_ddp else model)(token_in[0], token_in[1], batch_x_stamp[:, :-1, :]) 292 | loss, s1_loss, s2_loss = (model.module if use_ddp else model).head.compute_loss(logits[0], logits[1], token_out[0], token_out[1]) 293 | 294 | optimizer.zero_grad() 295 | loss.backward() 296 | torch.nn.utils.clip_grad_norm_((model.module if use_ddp else model).parameters(), max_norm=3.0) 297 | optimizer.step() 298 | scheduler.step() 299 | 300 | epoch_train_loss += loss.item() 301 | train_batches += 1 302 | 303 | if (batch_idx_global + 1) % config.log_interval == 0: 304 | lr = optimizer.param_groups[0]['lr'] 305 | log_msg = (f"[Epoch {epoch+1}/{config.basemodel_epochs}, Step {batch_idx+1}/{len(train_loader)}] " 306 | f"LR: {lr:.6f}, Loss: {loss.item():.4f}") 307 | logger.info(log_msg) 308 | if rank == 0: 309 | print(log_msg) 310 | 311 | batch_idx_global += 1 312 | 313 | model.eval() 314 | val_loss = 0.0 315 | val_batches = 0 316 | 317 | with torch.no_grad(): 318 | for batch_x, batch_x_stamp in val_loader: 319 | batch_x = batch_x.to(device, non_blocking=True) 320 | batch_x_stamp = batch_x_stamp.to(device, non_blocking=True) 321 | 322 | token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True) 323 | token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]] 324 | token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]] 325 | 326 | logits = (model.module if use_ddp else model)(token_in[0], token_in[1], batch_x_stamp[:, :-1, :]) 327 | loss, _, _ = (model.module if use_ddp else model).head.compute_loss(logits[0], logits[1], token_out[0], token_out[1]) 328 | 329 | val_loss += loss.item() 330 | val_batches += 1 331 | 332 | if use_ddp: 333 | tensor_sum = torch.tensor([epoch_train_loss, train_batches, val_loss, val_batches], dtype=torch.float64, device=device) 334 | dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM) 335 | epoch_train_loss_all = tensor_sum[0].item() 336 | train_batches_all = int(tensor_sum[1].item()) 337 | val_loss_all = tensor_sum[2].item() 338 | val_batches_all = int(tensor_sum[3].item()) 339 | avg_train_loss = (epoch_train_loss_all / train_batches_all) if train_batches_all > 0 else 0.0 340 | avg_val_loss = (val_loss_all / val_batches_all) if val_batches_all > 0 else 0.0 341 | else: 342 | avg_train_loss = epoch_train_loss / train_batches if train_batches > 0 else 0 343 | avg_val_loss = val_loss / val_batches if val_batches > 0 else 0 344 | 345 | epoch_time = time.time() - epoch_start_time 346 | epoch_summary = (f"\n--- Epoch {epoch+1}/{config.basemodel_epochs} Summary ---\n" 347 | f"Training Loss: {avg_train_loss:.4f}\n" 348 | f"Validation Loss: {avg_val_loss:.4f}\n" 349 | f"Epoch Time: {epoch_time:.2f} seconds\n") 350 | logger.info(epoch_summary) 351 | if rank == 0: 352 | print(epoch_summary) 353 | 354 | if avg_val_loss < best_val_loss: 355 | best_val_loss = avg_val_loss 356 | if rank == 0: 357 | model_save_path = os.path.join(save_dir, "best_model") 358 | os.makedirs(model_save_path, exist_ok=True) 359 | (model.module if use_ddp else model).save_pretrained(model_save_path) 360 | save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})" 361 | logger.info(save_msg) 362 | print(save_msg) 363 | 364 | return best_val_loss 365 | 366 | 367 | def main(): 368 | import argparse 369 | 370 | parser = argparse.ArgumentParser(description='Kronos Basemodel Fine-tuning Training') 
371 | parser.add_argument('--config', type=str, default='config.yaml', 372 | help='Configuration file path (default: config.yaml)') 373 | args = parser.parse_args() 374 | 375 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 376 | print(f"Using device: {device}") 377 | 378 | config = CustomFinetuneConfig(args.config) 379 | 380 | os.makedirs(config.basemodel_save_path, exist_ok=True) 381 | 382 | log_dir = os.path.join(config.base_save_path, "logs") 383 | logger = setup_logging(config.exp_name, log_dir, 0) 384 | 385 | torch.manual_seed(config.seed) 386 | np.random.seed(config.seed) 387 | random.seed(config.seed) 388 | 389 | logger.info("Loading pretrained model or random initialization...") 390 | print("Loading pretrained model or random initialization...") 391 | if getattr(config, 'pre_trained_tokenizer', True): 392 | tokenizer = KronosTokenizer.from_pretrained(config.finetuned_tokenizer_path) 393 | else: 394 | import json, os 395 | print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for training") 396 | cfg_path_tok = os.path.join(config.pretrained_tokenizer_path if hasattr(config, 'pretrained_tokenizer_path') else config.finetuned_tokenizer_path, 'config.json') 397 | with open(cfg_path_tok, 'r') as f: 398 | arch_t = json.load(f) 399 | tokenizer = KronosTokenizer( 400 | d_in=arch_t.get('d_in', 6), 401 | d_model=arch_t.get('d_model', 256), 402 | n_heads=arch_t.get('n_heads', 4), 403 | ff_dim=arch_t.get('ff_dim', 512), 404 | n_enc_layers=arch_t.get('n_enc_layers', 4), 405 | n_dec_layers=arch_t.get('n_dec_layers', 4), 406 | ffn_dropout_p=arch_t.get('ffn_dropout_p', 0.0), 407 | attn_dropout_p=arch_t.get('attn_dropout_p', 0.0), 408 | resid_dropout_p=arch_t.get('resid_dropout_p', 0.0), 409 | s1_bits=arch_t.get('s1_bits', 10), 410 | s2_bits=arch_t.get('s2_bits', 10), 411 | beta=arch_t.get('beta', 0.05), 412 | gamma0=arch_t.get('gamma0', 1.0), 413 | gamma=arch_t.get('gamma', 1.1), 414 | zeta=arch_t.get('zeta', 0.05), 415 | group_size=arch_t.get('group_size', 4) 416 | ) 417 | 418 | if getattr(config, 'pre_trained_predictor', True): 419 | model = Kronos.from_pretrained(config.pretrained_predictor_path) 420 | else: 421 | import json, os 422 | print("pre_trained_predictor=False, randomly initializing Predictor architecture for training") 423 | cfg_path = os.path.join(config.pretrained_predictor_path, 'config.json') 424 | with open(cfg_path, 'r') as f: 425 | arch = json.load(f) 426 | model = Kronos( 427 | s1_bits=arch.get('s1_bits', 10), 428 | s2_bits=arch.get('s2_bits', 10), 429 | n_layers=arch.get('n_layers', 12), 430 | d_model=arch.get('d_model', 832), 431 | n_heads=arch.get('n_heads', 16), 432 | ff_dim=arch.get('ff_dim', 2048), 433 | ffn_dropout_p=arch.get('ffn_dropout_p', 0.2), 434 | attn_dropout_p=arch.get('attn_dropout_p', 0.0), 435 | resid_dropout_p=arch.get('resid_dropout_p', 0.2), 436 | token_dropout_p=arch.get('token_dropout_p', 0.0), 437 | learn_te=arch.get('learn_te', True) 438 | ) 439 | 440 | tokenizer = tokenizer.to(device) 441 | model = model.to(device) 442 | 443 | model_size = sum(p.numel() for p in model.parameters()) 444 | logger.info(f"Model parameters: {model_size:,}") 445 | print(f"Model parameters: {model_size:,}") 446 | 447 | logger.info("=== Training Configuration ===") 448 | logger.info(f"Data path: {config.data_path}") 449 | logger.info(f"Lookback window: {config.lookback_window}") 450 | logger.info(f"Predict window: {config.predict_window}") 451 | logger.info(f"Batch size: {config.batch_size}") 452 | logger.info(f"Learning 
rate: {config.predictor_learning_rate}") 453 | logger.info(f"Training epochs: {config.basemodel_epochs}") 454 | logger.info(f"Device: {device}") 455 | logger.info(f"Tokenizer path: {config.finetuned_tokenizer_path}") 456 | logger.info(f"Pretrained model path: {config.pretrained_predictor_path}") 457 | 458 | logger.info("Starting fine-tuning training...") 459 | print("Starting fine-tuning training...") 460 | best_val_loss = train_model(model, tokenizer, device, config, config.basemodel_save_path, logger) 461 | 462 | final_msg = f"Training completed! Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.basemodel_save_path}" 463 | logger.info(final_msg) 464 | print(final_msg) 465 | 466 | 467 | if __name__ == "__main__": 468 | main() 469 | -------------------------------------------------------------------------------- /model/module.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from einops import rearrange, reduce 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Function 7 | import torch.nn.functional as F 8 | 9 | 10 | class DifferentiableEntropyFunction(Function): 11 | @staticmethod 12 | def forward(ctx, zq, basis, K, eps): 13 | zb = (zq + 1) / 2 14 | zi = ((zb * basis).sum(-1)).to(torch.int64) 15 | cnt = torch.scatter_reduce(torch.zeros(2 ** K, device=zq.device, dtype=zq.dtype), 16 | 0, 17 | zi.flatten(), 18 | torch.ones_like(zi.flatten()).to(zq.dtype), 19 | 'sum') 20 | prob = (cnt + eps) / (cnt + eps).sum() 21 | H = -(prob * torch.log(prob)).sum() 22 | ctx.save_for_backward(zq, zi, prob) 23 | ctx.K = K 24 | return H 25 | 26 | @staticmethod 27 | def backward(ctx, grad_output): 28 | zq, zi, prob = ctx.saved_tensors 29 | grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K 30 | reord_grad = grad_array[zi.flatten()].reshape(zi.shape) 31 | grad_input = reord_grad.unsqueeze(-1) * zq 32 | return grad_input, None, None, None, None 33 | 34 | 35 | def codebook_entropy(zq, basis, K, eps=1e-4): 36 | return DifferentiableEntropyFunction.apply(zq, basis, K, eps) 37 | 38 | 39 | class BinarySphericalQuantizer(nn.Module): 40 | def __init__(self, embed_dim, beta, gamma0, gamma, zeta, 41 | input_format='bchw', 42 | soft_entropy=True, group_size=9, 43 | persample_entropy_compute='analytical', 44 | cb_entropy_compute='group', 45 | l2_norm=True, 46 | inv_temperature=1): 47 | """ 48 | Paper link: https://arxiv.org/pdf/2406.07548.pdf 49 | Here we use the official implementation of the BinarySphericalQuantizer. 
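In brief: each dimension of the input latent is binarized to +/-1 with a
straight-through estimator (see `quantize`), and the code is rescaled by
1/sqrt(embed_dim) when `l2_norm` is enabled so that it lies on the unit hypersphere.
The training signal combines a commitment loss (weighted by `beta`) with per-sample
and codebook entropy penalties (`gamma0`, `gamma`, scaled by `zeta`); the codebook
entropy is approximated over sub-codes of `group_size` bits to keep it tractable.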
50 | """ 51 | super().__init__() 52 | self.embed_dim = embed_dim 53 | self.beta = beta # loss weight for commit loss 54 | self.gamma0 = gamma0 # loss weight for entropy penalty 55 | self.gamma = gamma # loss weight for entropy penalty 56 | self.zeta = zeta # loss weight for entire entropy penalty 57 | self.input_format = input_format 58 | assert self.embed_dim % group_size == 0, "embed_dim must be divisible by group_size" 59 | self.num_groups = self.embed_dim // group_size 60 | self.group_size = group_size 61 | assert persample_entropy_compute in ['group', 'analytical'], "persample_entropy_compute must be either 'group' or 'analytical'" 62 | assert cb_entropy_compute in ['group', 'nce'], "cb_entropy_compute must be either 'group' or 'nce'" 63 | self.persample_entropy_compute = persample_entropy_compute 64 | self.cb_entropy_compute = cb_entropy_compute 65 | self.l2_norm = l2_norm 66 | self.inv_temperature = inv_temperature 67 | 68 | self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1)) 69 | self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1)) 70 | 71 | self.num_dimensions = 2 ** embed_dim 72 | self.bits_per_index = embed_dim 73 | 74 | # we only need to keep the codebook portion up to the group size 75 | # because we approximate the H loss with this subcode 76 | group_codes = torch.arange(2 ** self.group_size) 77 | group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:] 78 | self.register_buffer('group_codebook', group_codebook, persistent=False) 79 | 80 | self.soft_entropy = soft_entropy # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf 81 | 82 | def quantize(self, z): 83 | assert z.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}" 84 | 85 | zhat = torch.where(z > 0, 86 | torch.tensor(1, dtype=z.dtype, device=z.device), 87 | torch.tensor(-1, dtype=z.dtype, device=z.device)) 88 | return z + (zhat - z).detach() 89 | 90 | def forward(self, z): 91 | # if self.input_format == 'bchw': 92 | # z = rearrange(z, 'b c h w -> b h w c') 93 | zq = self.quantize(z) 94 | 95 | indices = self.codes_to_indexes(zq.detach()) 96 | group_indices = self.codes_to_group_indexes(zq.detach()) 97 | if not self.training: 98 | used_codes = torch.unique(indices, return_counts=False) 99 | else: 100 | used_codes = None 101 | 102 | q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. 
103 | 104 | if self.soft_entropy: 105 | persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z) 106 | entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy 107 | else: 108 | zb_by_sample = ((zq + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32) 109 | persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample) 110 | cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim) 111 | entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy 112 | 113 | zq = zq * q_scale 114 | 115 | # commit loss 116 | commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1)) 117 | 118 | # if self.input_format == 'bchw': 119 | # zq = rearrange(zq, 'b h w c -> b c h w') 120 | 121 | return ( 122 | zq, 123 | commit_loss + self.zeta * entropy_penalty / self.inv_temperature, 124 | {"H": cb_entropy, "used_codes": used_codes, "indices": indices, "group_indices": group_indices, 125 | "avg_prob": avg_prob} 126 | ) 127 | 128 | def soft_entropy_loss(self, z): 129 | # if we divide the code in subgroups of size group_size, the codebook will be of size 2 ** group_size 130 | # the sub-code is the last group_size bits of the full code 131 | group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1) 132 | divided_z = rearrange(z, '... (g c) -> ... g c', c=self.group_size) 133 | 134 | # we calculate the distance between the divided_z and the codebook for each subgroup 135 | distance = - 2 * torch.einsum('... g c, d c ->... g d', divided_z, group_code_book) 136 | prob = (-distance * self.inv_temperature).softmax(dim=-1) 137 | if self.persample_entropy_compute == 'analytical': 138 | if self.l2_norm: 139 | p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature) 140 | else: 141 | p = torch.sigmoid(-4 * z * self.inv_temperature) 142 | prob = torch.stack([p, 1 - p], dim=-1) 143 | per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() 144 | else: 145 | per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() 146 | 147 | # macro average of the probability of each subgroup 148 | avg_prob = reduce(prob, '... g d ->g d', 'mean') 149 | codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False) 150 | 151 | # the approximation of the entropy is the sum of the entropy of each subgroup 152 | return per_sample_entropy, codebook_entropy.sum(), avg_prob 153 | 154 | def get_hard_per_sample_entropy(self, zb_by_sample): 155 | probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1] 156 | persample_entropy = - probs_per_dim * torch.log(probs_per_dim + 1e-8) - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8) 157 | persample_entropy = persample_entropy.sum(-1) 158 | return persample_entropy.mean() 159 | 160 | def codes_to_indexes(self, zhat): 161 | """Converts a `code` to an index in the codebook. 162 | Args: 163 | zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1} 164 | """ 165 | assert zhat.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}" 166 | return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64) 167 | 168 | def codes_to_group_indexes(self, zhat): 169 | """Converts a `code` to a list of indexes (in groups) in the codebook. 170 | Args: 171 | zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1} 172 | """ 173 | zhat_in_group = rearrange(zhat, 'b ... (g c) -> b ... 
g c', c=self.group_size) 174 | return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64) 175 | 176 | def indexes_to_codes(self, indices): 177 | """Inverse of `codes_to_indexes`.""" 178 | indices = indices.unsqueeze(-1) 179 | codes_non_centered = torch.remainder( 180 | torch.floor_divide(indices, self.basis), 2 181 | ) 182 | return codes_non_centered * 2 - 1 183 | 184 | def group_indexes_to_codes(self, group_indices): 185 | """Inverse of `codes_to_group_indexes`.""" 186 | group_indices = group_indices.unsqueeze(-1) 187 | codes_non_centered = torch.remainder( 188 | torch.floor_divide(group_indices, self.group_basis), 2 189 | ) 190 | codes_non_centered = rearrange(codes_non_centered, 'b ... g c -> b ... (g c)') 191 | return codes_non_centered * 2 - 1 192 | 193 | def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True): 194 | if normalize: 195 | probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True) 196 | else: 197 | probs = count 198 | H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim) 199 | return H 200 | 201 | def get_group_codebook_entry(self, group_indices): 202 | z_q = self.group_indexes_to_codes(group_indices) 203 | q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. 204 | z_q = z_q * q_scale 205 | if self.input_format == 'bchw': 206 | h = w = int(z_q.shape[1] ** 0.5) 207 | assert h * w == z_q.shape[1], 'Invalid sequence length' 208 | z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h) 209 | return z_q 210 | 211 | def get_codebook_entry(self, indices): 212 | z_q = self.indexes_to_codes(indices) 213 | q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. 214 | z_q = z_q * q_scale 215 | if self.input_format == 'bchw': 216 | h = w = int(z_q.shape[1] ** 0.5) 217 | assert h * w == z_q.shape[1], 'Invalid sequence length' 218 | z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h) 219 | return z_q 220 | 221 | 222 | class BSQuantizer(nn.Module): 223 | 224 | def __init__(self, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size): 225 | super().__init__() 226 | self.codebook_dim = s1_bits + s2_bits 227 | self.s1_bits = s1_bits 228 | self.s2_bits = s2_bits 229 | self.bsq = BinarySphericalQuantizer(self.codebook_dim, beta, gamma0, gamma, zeta, group_size=group_size) 230 | 231 | def bits_to_indices(self, bits): 232 | bits = (bits >= 0).to(torch.long) 233 | indices = 2 ** torch.arange( 234 | 0, 235 | bits.shape[-1], 236 | 1, 237 | dtype=torch.long, 238 | device=bits.device, 239 | ) 240 | return (bits * indices).sum(-1) 241 | 242 | def forward(self, z, half=False): 243 | z = F.normalize(z, dim=-1) 244 | quantized, bsq_loss, metrics = self.bsq(z) 245 | if half: 246 | q_pre = quantized[:, :, :self.s1_bits] 247 | q_post = quantized[:, :, self.s1_bits:] 248 | z_indices = [self.bits_to_indices(q_pre), self.bits_to_indices(q_post)] 249 | else: 250 | z_indices = self.bits_to_indices(quantized) 251 | return bsq_loss, quantized, z_indices 252 | 253 | 254 | class RMSNorm(torch.nn.Module): 255 | def __init__(self, dim: int, eps: float = 1e-5): 256 | super().__init__() 257 | self.eps = eps 258 | self.weight = nn.Parameter(torch.ones(dim)) 259 | 260 | def _norm(self, x): 261 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) 262 | 263 | def forward(self, x): 264 | output = self._norm(x.float()).type_as(x) 265 | return output * self.weight 266 | 267 | 268 | class FeedForward(nn.Module): 269 | def __init__(self, d_model, ff_dim, ffn_dropout_p=0.0): 270 | super().__init__() 271 | 272 | self.w1 = nn.Linear(d_model, ff_dim, bias=False)
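# Together with w3 below, w1 forms the gated (SwiGLU-style) projection; w2 maps the
# activated product back to d_model (see forward(): w2(silu(w1(x)) * w3(x))).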
253 | 
254 | class RMSNorm(torch.nn.Module):
255 |     def __init__(self, dim: int, eps: float = 1e-5):
256 |         super().__init__()
257 |         self.eps = eps
258 |         self.weight = nn.Parameter(torch.ones(dim))
259 | 
260 |     def _norm(self, x):
261 |         return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
262 | 
263 |     def forward(self, x):
264 |         output = self._norm(x.float()).type_as(x)
265 |         return output * self.weight
266 | 
267 | 
268 | class FeedForward(nn.Module):
269 |     def __init__(self, d_model, ff_dim, ffn_dropout_p=0.0):
270 |         super().__init__()
271 | 
272 |         self.w1 = nn.Linear(d_model, ff_dim, bias=False)
273 |         self.w3 = nn.Linear(d_model, ff_dim, bias=False)
274 |         self.w2 = nn.Linear(ff_dim, d_model, bias=False)
275 |         self.ffn_dropout = nn.Dropout(ffn_dropout_p)
276 | 
277 |     def forward(self, x):
278 |         return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
279 | 
280 | 
281 | class RotaryPositionalEmbedding(nn.Module):
282 |     def __init__(self, dim):
283 |         super().__init__()
284 |         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
285 |         self.register_buffer("inv_freq", inv_freq)
286 |         self.seq_len_cached = None
287 |         self.cos_cached = None
288 |         self.sin_cached = None
289 | 
290 |     def _update_cos_sin_cache(self, x, seq_len):
291 |         if seq_len != self.seq_len_cached:
292 |             self.seq_len_cached = seq_len
293 |             t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
294 |             freqs = torch.einsum('i,j->ij', t, self.inv_freq)
295 |             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
296 |             self.cos_cached = emb.cos()[None, None, :, :]
297 |             self.sin_cached = emb.sin()[None, None, :, :]
298 |         return self.cos_cached, self.sin_cached
299 | 
300 |     def forward(self, q, k):
301 |         cos, sin = self._update_cos_sin_cache(q, q.shape[-2])
302 |         return (
303 |             (q * cos) + (self._rotate_half(q) * sin),
304 |             (k * cos) + (self._rotate_half(k) * sin),
305 |         )
306 | 
307 |     def _rotate_half(self, x):
308 |         x1, x2 = x.chunk(2, dim=-1)
309 |         return torch.cat((-x2, x1), dim=-1)
310 | 
311 | 
312 | def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, training=True) -> torch.Tensor:
313 |     L, S = query.size(-2), key.size(-2)
314 |     scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
315 |     attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.device)
316 | 
317 |     if is_causal:
318 |         assert attn_mask is None
319 |         temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).to(query.device)
320 |         attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
321 |         attn_bias = attn_bias.to(query.dtype)
322 | 
323 |     attn_weight = query @ key.transpose(-2, -1) * scale_factor
324 |     attn_weight += attn_bias
325 | 
326 |     if attn_mask is not None:
327 |         attn_mask_bias = torch.zeros_like(attn_weight)
328 |         if attn_mask.dtype == torch.bool:
329 |             attn_mask_bias.masked_fill_(attn_mask, float("-inf"))
330 |         else:
331 |             attn_mask_bias += attn_mask
332 |         attn_weight += attn_mask_bias
333 | 
334 |     attn_weight = torch.softmax(attn_weight, dim=-1)
335 |     attn_weight = torch.dropout(attn_weight, dropout_p, train=training)
336 |     return attn_weight @ value
337 | 
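# Sanity-check sketch: with no padding mask and dropout disabled, the reference
# implementation above should agree with PyTorch's fused kernel up to floating-
# point noise (the shapes below are arbitrary):
#
#   q, k, v = (torch.randn(1, 2, 8, 16) for _ in range(3))
#   out_ref = scaled_dot_product_attention(q, k, v, is_causal=True, training=False)
#   out_fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
#   # torch.allclose(out_ref, out_fused, atol=1e-6) is expected to hold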
338 | 
339 | class MultiHeadAttentionWithRoPE(nn.Module):
340 |     def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout_p=0.0):
341 |         super().__init__()
342 |         self.d_model = d_model
343 |         self.n_heads = n_heads
344 |         self.head_dim = d_model // n_heads
345 | 
346 |         self.q_proj = nn.Linear(d_model, d_model)
347 |         self.k_proj = nn.Linear(d_model, d_model)
348 |         self.v_proj = nn.Linear(d_model, d_model)
349 |         self.out_proj = nn.Linear(d_model, d_model)
350 |         self.rotary = RotaryPositionalEmbedding(self.head_dim)
351 |         self.attn_dropout_p = attn_dropout_p
352 |         self.resid_dropout = nn.Dropout(resid_dropout_p)
353 | 
354 |     def forward(self, x, key_padding_mask=None):
355 |         batch_size, seq_len, _ = x.shape
356 | 
357 |         q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
358 |         k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
359 |         v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
360 | 
361 |         q, k = self.rotary(q, k)
362 | 
363 |         if key_padding_mask is not None:
364 |             attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, seq_len]
365 |             attn_mask = attn_mask.expand(-1, self.n_heads, seq_len, -1)  # [batch, n_heads, q_len, k_len]
366 |         else:
367 |             attn_mask = None
368 | 
369 |         attn_output = scaled_dot_product_attention(
370 |             q, k, v,
371 |             attn_mask=attn_mask,
372 |             dropout_p=self.attn_dropout_p,
373 |             is_causal=True,
374 |             training=self.training
375 |         )
376 | 
377 |         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
378 |         return self.resid_dropout(self.out_proj(attn_output))
379 | 
380 | 
381 | class MultiHeadCrossAttentionWithRoPE(nn.Module):
382 |     def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout=0.0):
383 |         super().__init__()
384 |         self.d_model = d_model
385 |         self.n_heads = n_heads
386 |         self.head_dim = d_model // n_heads
387 | 
388 |         self.q_proj = nn.Linear(d_model, d_model)
389 |         self.k_proj = nn.Linear(d_model, d_model)
390 |         self.v_proj = nn.Linear(d_model, d_model)
391 |         self.out_proj = nn.Linear(d_model, d_model)
392 |         self.rotary = RotaryPositionalEmbedding(self.head_dim)
393 |         self.attn_dropout_p = attn_dropout_p
394 |         self.resid_dropout = nn.Dropout(resid_dropout)
395 | 
396 |     def forward(self, query, key, value, key_padding_mask=None):
397 |         batch_size, q_len, _ = query.shape
398 |         _, seq_len, _ = key.shape
399 | 
400 |         q = self.q_proj(query).view(batch_size, q_len, self.n_heads, self.head_dim).transpose(1, 2)
401 |         k = self.k_proj(key).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
402 |         v = self.v_proj(value).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
403 | 
404 |         q, k = self.rotary(q, k)
405 | 
406 |         if key_padding_mask is not None:
407 |             attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
408 |             attn_mask = attn_mask.expand(-1, self.n_heads, q_len, -1)
409 |         else:
410 |             attn_mask = None
411 | 
412 |         is_causal_flag = self.training  # causal masking is applied only in training mode
413 | 
414 |         attn_output = scaled_dot_product_attention(
415 |             q, k, v,
416 |             attn_mask=attn_mask,
417 |             dropout_p=self.attn_dropout_p,
418 |             is_causal=is_causal_flag,
419 |             training=self.training
420 |         )
421 | 
422 |         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
423 |         return self.resid_dropout(self.out_proj(attn_output))
424 | 
425 | 
426 | class HierarchicalEmbedding(nn.Module):
427 |     def __init__(self, s1_bits, s2_bits, d_model=256):
428 |         super().__init__()
429 |         self.s1_bits = s1_bits
430 |         self.s2_bits = s2_bits
431 | 
432 |         vocab_s1 = 2 ** s1_bits
433 |         vocab_s2 = 2 ** s2_bits
434 | 
435 |         self.emb_s1 = nn.Embedding(vocab_s1, d_model)
436 |         self.emb_s2 = nn.Embedding(vocab_s2, d_model)
437 |         self.d_model = d_model
438 |         self.fusion_proj = nn.Linear(d_model * 2, d_model)
439 | 
440 |         nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5)
441 |         nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5)
442 | 
443 |     def forward(self, token_ids):
444 |         """Inputs:
445 |             token_ids: [batch_size, seq_len] token IDs, or a (s1_ids, s2_ids) pair
446 |         Output: [batch_size, seq_len, d_model]
447 |         """
448 |         if isinstance(token_ids, (tuple, list)):
449 |             s1_ids, s2_ids = token_ids
450 |         else:
451 |             s1_ids, s2_ids = self.split_token(token_ids, self.s2_bits)
452 |         s1_emb = self.emb_s1(s1_ids) * math.sqrt(self.d_model)
453 |         s2_emb = self.emb_s2(s2_ids) * math.sqrt(self.d_model)
454 |         return self.fusion_proj(torch.cat([s1_emb, s2_emb], dim=-1))
455 | 
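# A usage sketch for HierarchicalEmbedding. Passing the (s1_ids, s2_ids) pair
# directly avoids relying on split_token, which does not appear in the class as
# shown here (the bit widths and shapes below are illustrative):
#
#   emb = HierarchicalEmbedding(s1_bits=8, s2_bits=8, d_model=256)
#   s1_ids = torch.randint(0, 2 ** 8, (2, 32))
#   s2_ids = torch.randint(0, 2 ** 8, (2, 32))
#   emb((s1_ids, s2_ids)).shape               # torch.Size([2, 32, 256])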
456 | 
457 | class DependencyAwareLayer(nn.Module):
458 |     def __init__(self, d_model, n_heads=4, attn_dropout_p=0.0, resid_dropout=0.0):
459 |         super().__init__()
460 |         self.cross_attn = MultiHeadCrossAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout)
461 |         self.norm = RMSNorm(d_model)
462 | 
463 |     def forward(self, hidden_states, sibling_embed, key_padding_mask=None):
464 |         """hidden_states: [batch, seq_len, d_model]
465 |         sibling_embed: Embedding from another subtoken
466 |         """
467 |         attn_out = self.cross_attn(
468 |             query=sibling_embed,
469 |             key=hidden_states,
470 |             value=hidden_states,
471 |             key_padding_mask=key_padding_mask
472 |         )
473 |         return self.norm(hidden_states + attn_out)
474 | 
475 | 
476 | class TransformerBlock(nn.Module):
477 |     def __init__(self, d_model, n_heads, ff_dim=1024, ffn_dropout_p=0.0, attn_dropout_p=0.0, resid_dropout_p=0.0):
478 |         super().__init__()
479 |         self.norm1 = RMSNorm(d_model)
480 |         self.self_attn = MultiHeadAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout_p)
481 |         self.norm2 = RMSNorm(d_model)
482 |         self.ffn = FeedForward(d_model, ff_dim, ffn_dropout_p)
483 | 
484 |     def forward(self, x, key_padding_mask=None):
485 |         residual = x
486 |         x = self.norm1(x)
487 |         attn_out = self.self_attn(x, key_padding_mask=key_padding_mask)
488 |         x = residual + attn_out
489 | 
490 |         residual = x
491 |         x = self.norm2(x)
492 |         ffn_out = self.ffn(x)
493 |         x = residual + ffn_out
494 |         return x
495 | 
496 | 
497 | class DualHead(nn.Module):
498 |     def __init__(self, s1_bits, s2_bits, d_model):
499 |         super().__init__()
500 |         self.vocab_s1 = 2 ** s1_bits
501 |         self.vocab_s2 = 2 ** s2_bits
502 |         self.proj_s1 = nn.Linear(d_model, self.vocab_s1)
503 |         self.proj_s2 = nn.Linear(d_model, self.vocab_s2)
504 | 
505 |     def compute_loss(self, s1_logits, s2_logits, s1_targets, s2_targets, padding_mask=None):
506 |         if padding_mask is not None:
507 |             valid_mask = (padding_mask == 0)
508 |             s1_logits = s1_logits[valid_mask]
509 |             s2_logits = s2_logits[valid_mask]
510 |             s1_targets = s1_targets[valid_mask]
511 |             s2_targets = s2_targets[valid_mask]
512 |             ce_s1 = F.cross_entropy(s1_logits, s1_targets)
513 |             ce_s2 = F.cross_entropy(s2_logits, s2_targets)
514 |         else:
515 |             ce_s1 = F.cross_entropy(s1_logits.reshape(-1, self.vocab_s1), s1_targets.reshape(-1))
516 |             ce_s2 = F.cross_entropy(s2_logits.reshape(-1, self.vocab_s2), s2_targets.reshape(-1))
517 |         ce_loss = (ce_s1 + ce_s2) / 2
518 |         return ce_loss, ce_s1, ce_s2
519 | 
520 |     def forward(self, x):
521 |         return self.proj_s1(x)
522 | 
523 |     def cond_forward(self, x2):
524 |         return self.proj_s2(x2)
525 | 
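# A loss-computation sketch for DualHead (bit widths, d_model, and shapes are
# assumed for illustration):
#
#   head = DualHead(s1_bits=8, s2_bits=8, d_model=256)
#   h1, h2 = torch.randn(2, 32, 256), torch.randn(2, 32, 256)
#   s1_logits, s2_logits = head(h1), head.cond_forward(h2)
#   s1_targets = torch.randint(0, 2 ** 8, (2, 32))
#   s2_targets = torch.randint(0, 2 ** 8, (2, 32))
#   ce, ce_s1, ce_s2 = head.compute_loss(s1_logits, s2_logits, s1_targets, s2_targets)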
526 | 
527 | class FixedEmbedding(nn.Module):
528 |     def __init__(self, c_in, d_model):
529 |         super(FixedEmbedding, self).__init__()
530 | 
531 |         w = torch.zeros(c_in, d_model).float()
532 |         w.requires_grad = False
533 | 
534 |         position = torch.arange(0, c_in).float().unsqueeze(1)
535 |         div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
536 | 
537 |         w[:, 0::2] = torch.sin(position * div_term)
538 |         w[:, 1::2] = torch.cos(position * div_term)
539 | 
540 |         self.emb = nn.Embedding(c_in, d_model)
541 |         self.emb.weight = nn.Parameter(w, requires_grad=False)
542 | 
543 |     def forward(self, x):
544 |         return self.emb(x).detach()
545 | 
546 | 
547 | class TemporalEmbedding(nn.Module):
548 |     def __init__(self, d_model, learn_pe):
549 |         super(TemporalEmbedding, self).__init__()
550 | 
551 |         minute_size = 60
552 |         hour_size = 24
553 |         weekday_size = 7
554 |         day_size = 32
555 |         month_size = 13
556 | 
557 |         Embed = FixedEmbedding if not learn_pe else nn.Embedding
558 |         self.minute_embed = Embed(minute_size, d_model)
559 |         self.hour_embed = Embed(hour_size, d_model)
560 |         self.weekday_embed = Embed(weekday_size, d_model)
561 |         self.day_embed = Embed(day_size, d_model)
562 |         self.month_embed = Embed(month_size, d_model)
563 | 
564 |     def forward(self, x):
565 |         x = x.long()
566 | 
567 |         minute_x = self.minute_embed(x[:, :, 0])
568 |         hour_x = self.hour_embed(x[:, :, 1])
569 |         weekday_x = self.weekday_embed(x[:, :, 2])
570 |         day_x = self.day_embed(x[:, :, 3])
571 |         month_x = self.month_embed(x[:, :, 4])
572 | 
573 |         return hour_x + weekday_x + day_x + month_x + minute_x
574 | 
575 | 
576 | 
577 | 
578 | 
579 | 
580 | 
581 | 
582 | 
--------------------------------------------------------------------------------