├── finetune ├── utils │ ├── __init__.py │ └── training_utils.py ├── qlib_data_preprocess.py ├── dataset.py ├── config.py ├── train_predictor.py ├── train_tokenizer.py └── qlib_test.py ├── figures ├── logo.png ├── overview.png ├── prediction_example.png └── backtest_result_example.png ├── webui ├── requirements.txt ├── start.sh ├── run.py └── README.md ├── requirements.txt ├── finetune_csv ├── examples │ ├── HK_ali_09988_kline_5min_all_historical_20250919_073929.png │ ├── HK_ali_09988_kline_5min_all_historical_20250919_073944.png │ ├── HK_ali_09988_kline_5min_all_historical_20250919_074012.png │ ├── HK_ali_09988_kline_5min_all_historical_20250919_074042.png │ └── HK_ali_09988_kline_5min_all_historical_20250919_074251.png ├── configs │ └── config_ali09988_candle-5min.yaml ├── README_CN.md ├── README.md ├── config_loader.py ├── finetune_tokenizer.py ├── train_sequential.py └── finetune_base_model.py ├── model ├── __init__.py └── module.py ├── .gitignore ├── LICENSE ├── examples ├── prediction_wo_vol_example.py ├── prediction_batch_example.py └── prediction_example.py └── README.md /finetune/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figures/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mon2026/Kronos/HEAD/figures/logo.png -------------------------------------------------------------------------------- /figures/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mon2026/Kronos/HEAD/figures/overview.png -------------------------------------------------------------------------------- /figures/prediction_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mon2026/Kronos/HEAD/figures/prediction_example.png -------------------------------------------------------------------------------- /figures/backtest_result_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mon2026/Kronos/HEAD/figures/backtest_result_example.png -------------------------------------------------------------------------------- /webui/requirements.txt: -------------------------------------------------------------------------------- 1 | flask==2.3.3 2 | flask-cors==4.0.0 3 | pandas==2.2.2 4 | numpy==1.24.3 5 | plotly==5.17.0 6 | torch>=2.1.0 7 | huggingface_hub==0.33.1 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | torch 4 | 5 | einops==0.8.1 6 | huggingface_hub==0.33.1 7 | matplotlib==3.9.3 8 | pandas==2.2.2 9 | tqdm==4.67.1 10 | safetensors==0.6.2 11 | -------------------------------------------------------------------------------- /finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_073929.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mon2026/Kronos/HEAD/finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_073929.png -------------------------------------------------------------------------------- /finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_073944.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mon2026/Kronos/HEAD/finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_073944.png -------------------------------------------------------------------------------- /finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_074012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mon2026/Kronos/HEAD/finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_074012.png -------------------------------------------------------------------------------- /finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_074042.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mon2026/Kronos/HEAD/finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_074042.png -------------------------------------------------------------------------------- /finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_074251.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mon2026/Kronos/HEAD/finetune_csv/examples/HK_ali_09988_kline_5min_all_historical_20250919_074251.png -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .kronos import KronosTokenizer, Kronos, KronosPredictor 2 | 3 | model_dict = { 4 | 'kronos_tokenizer': KronosTokenizer, 5 | 'kronos': Kronos, 6 | 'kronos_predictor': KronosPredictor 7 | } 8 | 9 | 10 | def get_model_class(model_name): 11 | if model_name in model_dict: 12 | return model_dict[model_name] 13 | else: 14 | print(f"Model {model_name} not found in model_dict") 15 | raise NotImplementedError 16 | 17 | 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | MANIFEST 23 | 24 | # Jupyter Notebook 25 | .ipynb_checkpoints 26 | 27 | # PyCharm 28 | .idea/ 29 | 30 | # VS Code 31 | .vscode/ 32 | 33 | # macOS 34 | .DS_Store 35 | .AppleDouble 36 | .LSOverride 37 | 38 | # Windows 39 | Thumbs.db 40 | ehthumbs.db 41 | Desktop.ini 42 | 43 | # Linux 44 | *~ 45 | 46 | # Data files (large files) 47 | *.feather 48 | *.parquet 49 | *.h5 50 | *.hdf5 51 | 52 | # Model files (large files) 53 | *.pth 54 | *.pt 55 | *.ckpt 56 | *.bin 57 | 58 | # Logs 59 | *.log 60 | logs/ 61 | 62 | # Environment 63 | .env 64 | .venv 65 | env/ 66 | venv/ 67 | ENV/ 68 | env.bak/ 69 | venv.bak/ 70 | 71 | # Temporary files 72 | *.tmp 73 | *.temp 74 | temp/ 75 | tmp/ 76 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 ShiYu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without 
limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /webui/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Kronos Web UI startup script 4 | 5 | echo "🚀 Starting Kronos Web UI..." 6 | echo "================================" 7 | 8 | # Check if Python is installed 9 | if ! command -v python3 &> /dev/null; then 10 | echo "❌ Python3 not installed, please install Python3 first" 11 | exit 1 12 | fi 13 | 14 | # Check if in correct directory 15 | if [ ! -f "app.py" ]; then 16 | echo "❌ Please run this script in the webui directory" 17 | exit 1 18 | fi 19 | 20 | # Check dependencies 21 | echo "📦 Checking dependencies..." 22 | if ! python3 -c "import flask, flask_cors, pandas, numpy, plotly" &> /dev/null; then 23 | echo "⚠️ Missing dependencies, installing..." 24 | pip3 install -r requirements.txt 25 | if [ $? -ne 0 ]; then 26 | echo "❌ Dependencies installation failed" 27 | exit 1 28 | fi 29 | echo "✅ Dependencies installation completed" 30 | else 31 | echo "✅ All dependencies installed" 32 | fi 33 | 34 | # Start application 35 | echo "🌐 Starting Web server..." 36 | echo "Access URL: http://localhost:7070" 37 | echo "Press Ctrl+C to stop server" 38 | echo "" 39 | 40 | python3 app.py 41 | -------------------------------------------------------------------------------- /examples/prediction_wo_vol_example.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import sys 4 | sys.path.append("../") 5 | from model import Kronos, KronosTokenizer, KronosPredictor 6 | 7 | 8 | def plot_prediction(kline_df, pred_df): 9 | pred_df.index = kline_df.index[-pred_df.shape[0]:] 10 | sr_close = kline_df['close'] 11 | sr_pred_close = pred_df['close'] 12 | sr_close.name = 'Ground Truth' 13 | sr_pred_close.name = "Prediction" 14 | 15 | close_df = pd.concat([sr_close, sr_pred_close], axis=1) 16 | 17 | fig, ax = plt.subplots(1, 1, figsize=(8, 4)) 18 | 19 | ax.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) 20 | ax.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5) 21 | ax.set_ylabel('Close Price', fontsize=14) 22 | ax.legend(loc='lower left', fontsize=12) 23 | ax.grid(True) 24 | 25 | plt.tight_layout() 26 | plt.show() 27 | 28 | 29 | # 1. Load Model and Tokenizer 30 | tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base") 31 | model = Kronos.from_pretrained("NeoQuasar/Kronos-small") 32 | 33 | # 2. 
Instantiate Predictor 34 | predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512) 35 | 36 | # 3. Prepare Data 37 | df = pd.read_csv("./data/XSHG_5min_600977.csv") 38 | df['timestamps'] = pd.to_datetime(df['timestamps']) 39 | 40 | lookback = 400 41 | pred_len = 120 42 | 43 | x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close']] 44 | x_timestamp = df.loc[:lookback-1, 'timestamps'] 45 | y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps'] 46 | 47 | # 4. Make Prediction 48 | pred_df = predictor.predict( 49 | df=x_df, 50 | x_timestamp=x_timestamp, 51 | y_timestamp=y_timestamp, 52 | pred_len=pred_len, 53 | T=1.0, 54 | top_p=0.9, 55 | sample_count=1, 56 | verbose=True 57 | ) 58 | 59 | # 5. Visualize Results 60 | print("Forecasted Data Head:") 61 | print(pred_df.head()) 62 | 63 | # Combine historical and forecasted data for plotting 64 | kline_df = df.loc[:lookback+pred_len-1] 65 | 66 | # visualize 67 | plot_prediction(kline_df, pred_df) 68 | 69 | -------------------------------------------------------------------------------- /finetune_csv/configs/config_ali09988_candle-5min.yaml: -------------------------------------------------------------------------------- 1 | #This is a template config for custom finetuning kronos on csv data 2 | #这是一份模板config,用于kronos的csv自定义数据微调 3 | 4 | data: 5 | data_path: "/xxxx/Kronos/finetune_csv/data/HK_ali_09988_kline_5min_all.csv" 6 | lookback_window: 512 7 | predict_window: 48 8 | max_context: 512 9 | clip: 5.0 10 | # dataset split ratio 11 | train_ratio: 0.9 12 | val_ratio: 0.1 13 | test_ratio: 0.0 14 | 15 | training: 16 | # control the training epochs of tokenizer and basemodel 17 | tokenizer_epochs: 30 18 | basemodel_epochs: 20 19 | batch_size: 32 20 | log_interval: 50 21 | num_workers: 6 22 | seed: 42 23 | 24 | tokenizer_learning_rate: 0.0002 25 | predictor_learning_rate: 0.000001 26 | 27 | adam_beta1: 0.9 28 | adam_beta2: 0.95 29 | adam_weight_decay: 0.1 30 | 31 | # gradient accumulation steps for tokenizer training 32 | accumulation_steps: 1 33 | 34 | # model path configuration 35 | model_paths: 36 | # pretrained model path 37 | pretrained_tokenizer: "/xxx/Kronos/pretrained/Kronos-Tokenizer-base" 38 | pretrained_predictor: "/xxx/Kronos/pretrained/Kronos-base" 39 | 40 | # experiment name - other paths will be generated based on this 41 | exp_name: "HK_ali_09988_kline_5min_all" 42 | base_path: "/xxx/Kronos/finetune_csv/finetuned/" 43 | 44 | # the following paths will be generated based on exp_name, no need to modify manually 45 | # way 1: leave empty string, the system will generate the full path 46 | base_save_path: "" # /xxxx/Kronos/finetune_csv/finetuned/{exp_name} 47 | finetuned_tokenizer: "" # /xxxx/Kronos/finetune_csv/finetuned/{exp_name}/tokenizer/best_model 48 | 49 | # way 2: use template string, {exp_name} will be replaced with the actual experiment name 50 | # base_save_path: "/xxxx/Kronos/finetune_csv/finetuned/{exp_name}" 51 | # finetuned_tokenizer: "/xxxx/Kronos/finetune_csv/finetuned/{exp_name}/tokenizer/best_model" 52 | 53 | tokenizer_save_name: "tokenizer" 54 | basemodel_save_name: "basemodel" 55 | 56 | experiment: 57 | name: "kronos_custom_finetune" 58 | description: "Custom finetune for HK stock data" 59 | use_comet: false 60 | 61 | # control the training phase 62 | train_tokenizer: true 63 | train_basemodel: true 64 | 65 | # if true, skip the existing model training 66 | skip_existing: false 67 | 68 | # device configuration 69 | device: 70 | use_cuda: true 71 | device_id: 0 72 | 73 | 
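# Illustrative resolution of the auto-generated paths above (way 1), assuming the
# exp_name and base_path values set in this file; adjust to your own layout:
#   base_save_path      -> /xxx/Kronos/finetune_csv/finetuned/HK_ali_09988_kline_5min_all
#   finetuned_tokenizer -> /xxx/Kronos/finetune_csv/finetuned/HK_ali_09988_kline_5min_all/tokenizer/best_model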
-------------------------------------------------------------------------------- /examples/prediction_batch_example.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import sys 4 | sys.path.append("../") 5 | from model import Kronos, KronosTokenizer, KronosPredictor 6 | 7 | 8 | def plot_prediction(kline_df, pred_df): 9 | pred_df.index = kline_df.index[-pred_df.shape[0]:] 10 | sr_close = kline_df['close'] 11 | sr_pred_close = pred_df['close'] 12 | sr_close.name = 'Ground Truth' 13 | sr_pred_close.name = "Prediction" 14 | 15 | sr_volume = kline_df['volume'] 16 | sr_pred_volume = pred_df['volume'] 17 | sr_volume.name = 'Ground Truth' 18 | sr_pred_volume.name = "Prediction" 19 | 20 | close_df = pd.concat([sr_close, sr_pred_close], axis=1) 21 | volume_df = pd.concat([sr_volume, sr_pred_volume], axis=1) 22 | 23 | fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True) 24 | 25 | ax1.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) 26 | ax1.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5) 27 | ax1.set_ylabel('Close Price', fontsize=14) 28 | ax1.legend(loc='lower left', fontsize=12) 29 | ax1.grid(True) 30 | 31 | ax2.plot(volume_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) 32 | ax2.plot(volume_df['Prediction'], label='Prediction', color='red', linewidth=1.5) 33 | ax2.set_ylabel('Volume', fontsize=14) 34 | ax2.legend(loc='upper left', fontsize=12) 35 | ax2.grid(True) 36 | 37 | plt.tight_layout() 38 | plt.show() 39 | 40 | 41 | # 1. Load Model and Tokenizer 42 | tokenizer = KronosTokenizer.from_pretrained('/home/csc/huggingface/Kronos-Tokenizer-base/') 43 | model = Kronos.from_pretrained("/home/csc/huggingface/Kronos-base/") 44 | 45 | # 2. Instantiate Predictor 46 | predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512) 47 | 48 | # 3. 
Prepare Data 49 | df = pd.read_csv("./data/XSHG_5min_600977.csv") 50 | df['timestamps'] = pd.to_datetime(df['timestamps']) 51 | 52 | lookback = 400 53 | pred_len = 120 54 | 55 | dfs = [] 56 | xtsp = [] 57 | ytsp = [] 58 | for i in range(5): 59 | idf = df.loc[(i*400):(i*400+lookback-1), ['open', 'high', 'low', 'close', 'volume', 'amount']] 60 | i_x_timestamp = df.loc[(i*400):(i*400+lookback-1), 'timestamps'] 61 | i_y_timestamp = df.loc[(i*400+lookback):(i*400+lookback+pred_len-1), 'timestamps'] 62 | 63 | dfs.append(idf) 64 | xtsp.append(i_x_timestamp) 65 | ytsp.append(i_y_timestamp) 66 | 67 | pred_df = predictor.predict_batch( 68 | df_list=dfs, 69 | x_timestamp_list=xtsp, 70 | y_timestamp_list=ytsp, 71 | pred_len=pred_len, 72 | ) 73 | -------------------------------------------------------------------------------- /examples/prediction_example.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import sys 4 | sys.path.append("../") 5 | from model import Kronos, KronosTokenizer, KronosPredictor 6 | 7 | 8 | def plot_prediction(kline_df, pred_df): 9 | pred_df.index = kline_df.index[-pred_df.shape[0]:] 10 | sr_close = kline_df['close'] 11 | sr_pred_close = pred_df['close'] 12 | sr_close.name = 'Ground Truth' 13 | sr_pred_close.name = "Prediction" 14 | 15 | sr_volume = kline_df['volume'] 16 | sr_pred_volume = pred_df['volume'] 17 | sr_volume.name = 'Ground Truth' 18 | sr_pred_volume.name = "Prediction" 19 | 20 | close_df = pd.concat([sr_close, sr_pred_close], axis=1) 21 | volume_df = pd.concat([sr_volume, sr_pred_volume], axis=1) 22 | 23 | fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True) 24 | 25 | ax1.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) 26 | ax1.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5) 27 | ax1.set_ylabel('Close Price', fontsize=14) 28 | ax1.legend(loc='lower left', fontsize=12) 29 | ax1.grid(True) 30 | 31 | ax2.plot(volume_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) 32 | ax2.plot(volume_df['Prediction'], label='Prediction', color='red', linewidth=1.5) 33 | ax2.set_ylabel('Volume', fontsize=14) 34 | ax2.legend(loc='upper left', fontsize=12) 35 | ax2.grid(True) 36 | 37 | plt.tight_layout() 38 | plt.show() 39 | 40 | 41 | # 1. Load Model and Tokenizer 42 | tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base") 43 | model = Kronos.from_pretrained("NeoQuasar/Kronos-small") 44 | 45 | # 2. Instantiate Predictor 46 | predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512) 47 | 48 | # 3. Prepare Data 49 | df = pd.read_csv("./data/XSHG_5min_600977.csv") 50 | df['timestamps'] = pd.to_datetime(df['timestamps']) 51 | 52 | lookback = 400 53 | pred_len = 120 54 | 55 | x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close', 'volume', 'amount']] 56 | x_timestamp = df.loc[:lookback-1, 'timestamps'] 57 | y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps'] 58 | 59 | # 4. Make Prediction 60 | pred_df = predictor.predict( 61 | df=x_df, 62 | x_timestamp=x_timestamp, 63 | y_timestamp=y_timestamp, 64 | pred_len=pred_len, 65 | T=1.0, 66 | top_p=0.9, 67 | sample_count=1, 68 | verbose=True 69 | ) 70 | 71 | # 5. 
Visualize Results 72 | print("Forecasted Data Head:") 73 | print(pred_df.head()) 74 | 75 | # Combine historical and forecasted data for plotting 76 | kline_df = df.loc[:lookback+pred_len-1] 77 | 78 | # visualize 79 | plot_prediction(kline_df, pred_df) 80 | 81 | -------------------------------------------------------------------------------- /webui/run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Kronos Web UI startup script 4 | """ 5 | 6 | import os 7 | import sys 8 | import subprocess 9 | import webbrowser 10 | import time 11 | 12 | def check_dependencies(): 13 | """Check if dependencies are installed""" 14 | try: 15 | import flask 16 | import flask_cors 17 | import pandas 18 | import numpy 19 | import plotly 20 | print("✅ All dependencies installed") 21 | return True 22 | except ImportError as e: 23 | print(f"❌ Missing dependency: {e}") 24 | print("Please run: pip install -r requirements.txt") 25 | return False 26 | 27 | def install_dependencies(): 28 | """Install dependencies""" 29 | print("Installing dependencies...") 30 | try: 31 | subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"]) 32 | print("✅ Dependencies installation completed") 33 | return True 34 | except subprocess.CalledProcessError: 35 | print("❌ Dependencies installation failed") 36 | return False 37 | 38 | def main(): 39 | """Main function""" 40 | print("🚀 Starting Kronos Web UI...") 41 | print("=" * 50) 42 | 43 | # Check dependencies 44 | if not check_dependencies(): 45 | print("\nAuto-install dependencies? (y/n): ", end="") 46 | if input().lower() == 'y': 47 | if not install_dependencies(): 48 | return 49 | else: 50 | print("Please manually install dependencies and retry") 51 | return 52 | 53 | # Check model availability 54 | try: 55 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 56 | from model import Kronos, KronosTokenizer, KronosPredictor 57 | print("✅ Kronos model library available") 58 | model_available = True 59 | except ImportError: 60 | print("⚠️ Kronos model library not available, will use simulated prediction") 61 | model_available = False 62 | 63 | # Start Flask application 64 | print("\n🌐 Starting Web server...") 65 | 66 | # Set environment variables 67 | os.environ['FLASK_APP'] = 'app.py' 68 | os.environ['FLASK_ENV'] = 'development' 69 | 70 | # Start server 71 | try: 72 | from app import app 73 | print("✅ Web server started successfully!") 74 | print(f"🌐 Access URL: http://localhost:7070") 75 | print("💡 Tip: Press Ctrl+C to stop server") 76 | 77 | # Auto-open browser 78 | time.sleep(2) 79 | webbrowser.open('http://localhost:7070') 80 | 81 | # Start Flask application 82 | app.run(debug=True, host='0.0.0.0', port=7070) 83 | 84 | except Exception as e: 85 | print(f"❌ Startup failed: {e}") 86 | print("Please check if port 7070 is occupied") 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /finetune_csv/README_CN.md: -------------------------------------------------------------------------------- 1 | # Kronos微调-支持自定义CSV数据集 2 | 3 | 这是一个在自定义的CSV格式数据上微调Kronos模型的完整流程。包含顺序训练(先训练tokenizer再训练predictor)和单独模块训练,同时支持分布式训练。 4 | 5 | 6 | ## 1. 
准备数据 7 | 8 | ### 数据格式 9 | 10 | CSV文件必须包含以下列: 11 | - `timestamps`: 每个数据点的时间戳 12 | - `open`: 开盘价 13 | - `high`: 最高价 14 | - `low`: 最低价 15 | - `close`: 收盘价 16 | - `volume`: 交易量 17 | - `amount`: 交易金额 18 | 19 | (volume和amount可以全0如果没有这部分的数据) 20 | 21 | ### 示例数据格式 22 | 23 | | timestamps | open | close | high | low | volume | amount | 24 | |------------|------|-------|------|-----|--------|--------| 25 | | 2019/11/26 9:35 | 182.45215 | 184.45215 | 184.95215 | 182.45215 | 15136000 | 0 | 26 | | 2019/11/26 9:40 | 184.35215 | 183.85215 | 184.55215 | 183.45215 | 4433300 | 0 | 27 | | 2019/11/26 9:45 | 183.85215 | 183.35215 | 183.95215 | 182.95215 | 3070900 | 0 | 28 | 29 | > **标准数据样例**: `data/HK_ali_09988_kline_5min_all.csv` 30 | 31 | ## 2. 准备config文件 32 | 33 | data_path及预训练模型路径需要修改,训练参数可以自己调节 34 | 35 | ```yaml 36 | # 数据配置 37 | data: 38 | data_path: "/path/to/your/data.csv" 39 | lookback_window: 512 # 要使用的历史数据点 40 | predict_window: 48 # 要预测的未来点数 41 | max_context: 512 # 最大上下文长度 42 | 43 | ... 44 | 45 | ``` 46 | 这里还有其他一些设置, `configs/config_ali09988_candle-5min.yaml` 有更详细的注释。 47 | 48 | ## 3. 训练 49 | 50 | ### 方法1: 直接顺序训练 51 | 52 | `train_sequential.py` 脚本自动处理完整的训练流程: 53 | 54 | ```bash 55 | # 完整训练(tokenizer + predictor) 56 | python train_sequential.py --config configs/config_ali09988_candle-5min.yaml 57 | 58 | # 跳过已存在的模型 59 | python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-existing 60 | 61 | # 只训练tokenizer 62 | python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-basemodel 63 | 64 | # 只训练predictor 65 | python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-tokenizer 66 | ``` 67 | 68 | ### 方法2: 单独组件训练 69 | 70 | 可以单独训练每个组件: 71 | 72 | ```bash 73 | # 步骤1: 训练tokenizer 74 | python finetune_tokenizer.py --config configs/config_ali09988_candle-5min.yaml 75 | 76 | # 步骤2: 训练predictor(需要微调后的tokenizer) 77 | python finetune_base_model.py --config configs/config_ali09988_candle-5min.yaml 78 | ``` 79 | 80 | ### DDP训练 81 | 82 | 如果有多卡,可以开启ddp加速训练: 83 | 84 | ```bash 85 | # 设置通信后端(NVIDIA GPU用nccl,CPU/混合用gloo) 86 | DIST_BACKEND=nccl \ 87 | torchrun --standalone --nproc_per_node=8 train_sequential.py --config configs/config_ali09988_candle-5min.yaml 88 | ``` 89 | 90 | ## 4. 训练结果 91 | 92 | 训练过程生成以下输出: 93 | 94 | ### 模型检查点 95 | - **Tokenizer**: 保存到 `{base_save_path}/{exp_name}/tokenizer/best_model/` 96 | - **Predictor**: 保存到 `{base_save_path}/{exp_name}/basemodel/best_model/` 97 | 98 | ### 训练日志 99 | - **控制台输出**: 实时训练进度和指标 100 | - **日志文件**: 详细日志保存到 `{base_save_path}/logs/` 101 | - **验证跟踪**: 基于验证损失保存最佳模型 102 | 103 | ## 5. 预测可视化 104 | 105 | 以下图像显示了kronos在阿里巴巴股票数据上微调后的示例训练结果: 106 | 107 |  108 | 109 |  110 | 111 |  112 | 113 |  114 | 115 |  116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /finetune/utils/training_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import datetime 4 | import numpy as np 5 | import torch 6 | import torch.distributed as dist 7 | 8 | 9 | def setup_ddp(): 10 | """ 11 | Initializes the distributed data parallel environment. 12 | 13 | This function relies on environment variables set by `torchrun` or a similar 14 | launcher. It initializes the process group and sets the CUDA device for the 15 | current process. 16 | 17 | Returns: 18 | tuple: A tuple containing (rank, world_size, local_rank). 
19 | """ 20 | if not dist.is_available(): 21 | raise RuntimeError("torch.distributed is not available.") 22 | 23 | dist.init_process_group(backend="nccl") 24 | rank = int(os.environ["RANK"]) 25 | world_size = int(os.environ["WORLD_SIZE"]) 26 | local_rank = int(os.environ["LOCAL_RANK"]) 27 | torch.cuda.set_device(local_rank) 28 | print( 29 | f"[DDP Setup] Global Rank: {rank}/{world_size}, " 30 | f"Local Rank (GPU): {local_rank} on device {torch.cuda.current_device()}" 31 | ) 32 | return rank, world_size, local_rank 33 | 34 | 35 | def cleanup_ddp(): 36 | """Cleans up the distributed process group.""" 37 | if dist.is_initialized(): 38 | dist.destroy_process_group() 39 | 40 | 41 | def set_seed(seed: int, rank: int = 0): 42 | """ 43 | Sets the random seed for reproducibility across all relevant libraries. 44 | 45 | Args: 46 | seed (int): The base seed value. 47 | rank (int): The process rank, used to ensure different processes have 48 | different seeds, which can be important for data loading. 49 | """ 50 | actual_seed = seed + rank 51 | random.seed(actual_seed) 52 | np.random.seed(actual_seed) 53 | torch.manual_seed(actual_seed) 54 | if torch.cuda.is_available(): 55 | torch.cuda.manual_seed_all(actual_seed) 56 | # The two lines below can impact performance, so they are often 57 | # reserved for final experiments where reproducibility is critical. 58 | torch.backends.cudnn.deterministic = True 59 | torch.backends.cudnn.benchmark = False 60 | 61 | 62 | def get_model_size(model: torch.nn.Module) -> str: 63 | """ 64 | Calculates the number of trainable parameters in a PyTorch model and returns 65 | it as a human-readable string. 66 | 67 | Args: 68 | model (torch.nn.Module): The PyTorch model. 69 | 70 | Returns: 71 | str: A string representing the model size (e.g., "175.0B", "7.1M", "50.5K"). 72 | """ 73 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 74 | 75 | if total_params >= 1e9: 76 | return f"{total_params / 1e9:.1f}B" # Billions 77 | elif total_params >= 1e6: 78 | return f"{total_params / 1e6:.1f}M" # Millions 79 | else: 80 | return f"{total_params / 1e3:.1f}K" # Thousands 81 | 82 | 83 | def reduce_tensor(tensor: torch.Tensor, world_size: int, op=dist.ReduceOp.SUM) -> torch.Tensor: 84 | """ 85 | Reduces a tensor's value across all processes in a distributed setup. 86 | 87 | Args: 88 | tensor (torch.Tensor): The tensor to be reduced. 89 | world_size (int): The total number of processes. 90 | op (dist.ReduceOp, optional): The reduction operation (SUM, AVG, etc.). 91 | Defaults to dist.ReduceOp.SUM. 92 | 93 | Returns: 94 | torch.Tensor: The reduced tensor, which will be identical on all processes. 95 | """ 96 | rt = tensor.clone() 97 | dist.all_reduce(rt, op=op) 98 | # Note: `dist.ReduceOp.AVG` is available in newer torch versions. 99 | # For compatibility, manual division is sometimes used after a SUM. 100 | if op == dist.ReduceOp.AVG: 101 | rt /= world_size 102 | return rt 103 | 104 | 105 | def format_time(seconds: float) -> str: 106 | """ 107 | Formats a duration in seconds into a human-readable H:M:S string. 108 | 109 | Args: 110 | seconds (float): The total seconds. 111 | 112 | Returns: 113 | str: The formatted time string (e.g., "0:15:32"). 
114 | """ 115 | return str(datetime.timedelta(seconds=int(seconds))) 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /finetune_csv/README.md: -------------------------------------------------------------------------------- 1 | # Kronos Fine-tuning on Custom CSV Datasets 2 | 3 | This module provides a comprehensive pipeline for fine-tuning Kronos models on your own CSV-formatted financial data. It supports both sequential training (tokenizer followed by predictor) and individual component training, with full distributed training capabilities. 4 | 5 | 6 | ## 1. Data Preparation 7 | 8 | ### Required Data Format 9 | 10 | Your CSV file must contain the following columns: 11 | - `timestamps`: DateTime stamps for each data point 12 | - `open`: Opening price 13 | - `high`: Highest price 14 | - `low`: Lowest price 15 | - `close`: Closing price 16 | - `volume`: Trading volume 17 | - `amount`: Trading amount 18 | 19 | (volume and amount can be 0 if not available) 20 | 21 | ### Sample Data Format 22 | 23 | | timestamps | open | close | high | low | volume | amount | 24 | |------------|------|-------|------|-----|--------|--------| 25 | | 2019/11/26 9:35 | 182.45215 | 184.45215 | 184.95215 | 182.45215 | 15136000 | 0 | 26 | | 2019/11/26 9:40 | 184.35215 | 183.85215 | 184.55215 | 183.45215 | 4433300 | 0 | 27 | | 2019/11/26 9:45 | 183.85215 | 183.35215 | 183.95215 | 182.95215 | 3070900 | 0 | 28 | 29 | > **Reference**: Check `data/HK_ali_09988_kline_5min_all.csv` for a complete example of the proper data format. 30 | 31 | 32 | ## 2. Config Preparation 33 | 34 | 35 | Please edit the correct data path & pretrained model path and set your training parameters. 36 | 37 | ```yaml 38 | # Data configuration 39 | data: 40 | data_path: "/path/to/your/data.csv" 41 | lookback_window: 512 # Historical data points to use 42 | predict_window: 48 # Future points to predict 43 | max_context: 512 # Maximum context length 44 | 45 | ... 46 | 47 | ``` 48 | There are some other settings here, please see `configs/config_ali09988_candle-5min.yaml` for more comments. 49 | 50 | ## 3. 
Training 51 | 52 | ### Method 1: Sequential Training (Recommended) 53 | 54 | The `train_sequential.py` script handles the complete training pipeline automatically: 55 | 56 | ```bash 57 | # Complete training (tokenizer + predictor) 58 | python train_sequential.py --config configs/config_ali09988_candle-5min.yaml 59 | 60 | # Skip existing models 61 | python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-existing 62 | 63 | # Only train tokenizer 64 | python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-basemodel 65 | 66 | # Only train predictor 67 | python train_sequential.py --config configs/config_ali09988_candle-5min.yaml --skip-tokenizer 68 | ``` 69 | 70 | ### Method 2: Individual Component Training 71 | 72 | Train each component separately for more control: 73 | 74 | ```bash 75 | # Step 1: Train tokenizer 76 | python finetune_tokenizer.py --config configs/config_ali09988_candle-5min.yaml 77 | 78 | # Step 2: Train predictor (requires fine-tuned tokenizer) 79 | python finetune_base_model.py --config configs/config_ali09988_candle-5min.yaml 80 | ``` 81 | 82 | ### DDP Training 83 | 84 | For faster training on multiple GPUs: 85 | 86 | ```bash 87 | # Set communication backend (nccl for NVIDIA GPUs, gloo for CPU/mixed) 88 | DIST_BACKEND=nccl \ 89 | torchrun --standalone --nproc_per_node=8 train_sequential.py --config configs/config_ali09988_candle-5min.yaml 90 | ``` 91 | 92 | ## 4. Training Results 93 | 94 | The training process generates several outputs: 95 | 96 | ### Model Checkpoints 97 | - **Tokenizer**: Saved to `{base_save_path}/{exp_name}/tokenizer/best_model/` 98 | - **Predictor**: Saved to `{base_save_path}/{exp_name}/basemodel/best_model/` 99 | 100 | ### Training Logs 101 | - **Console output**: Real-time training progress and metrics 102 | - **Log files**: Detailed logs saved to `{base_save_path}/logs/` 103 | - **Validation tracking**: Best models are saved based on validation loss 104 | 105 | ## 5. Prediction Vis 106 | 107 | The following images show example training results on alibaba (HK stock) data: 108 | 109 |  110 | 111 |  112 | 113 |  114 | 115 |  116 | 117 |  118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /webui/README.md: -------------------------------------------------------------------------------- 1 | # Kronos Web UI 2 | 3 | Web user interface for Kronos financial prediction model, providing intuitive graphical operation interface. 
4 | 5 | ## ✨ Features 6 | 7 | - **Multi-format data support**: Supports CSV, Feather and other financial data formats 8 | - **Smart time window**: Fixed 400+120 data point time window slider selection 9 | - **Real model prediction**: Integrated real Kronos model, supports multiple model sizes 10 | - **Prediction quality control**: Adjustable temperature, nucleus sampling, sample count and other parameters 11 | - **Multi-device support**: Supports CPU, CUDA, MPS and other computing devices 12 | - **Comparison analysis**: Detailed comparison between prediction results and actual data 13 | - **K-line chart display**: Professional financial K-line chart display 14 | 15 | ## 🚀 Quick Start 16 | 17 | ### Method 1: Start with Python script 18 | ```bash 19 | cd webui 20 | python run.py 21 | ``` 22 | 23 | ### Method 2: Start with Shell script 24 | ```bash 25 | cd webui 26 | chmod +x start.sh 27 | ./start.sh 28 | ``` 29 | 30 | ### Method 3: Start Flask application directly 31 | ```bash 32 | cd webui 33 | python app.py 34 | ``` 35 | 36 | After successful startup, visit http://localhost:7070 37 | 38 | ## 📋 Usage Steps 39 | 40 | 1. **Load data**: Select financial data file from data directory 41 | 2. **Load model**: Select Kronos model and computing device 42 | 3. **Set parameters**: Adjust prediction quality parameters 43 | 4. **Select time window**: Use slider to select 400+120 data point time range 44 | 5. **Start prediction**: Click prediction button to generate results 45 | 6. **View results**: View prediction results in charts and tables 46 | 47 | ## 🔧 Prediction Quality Parameters 48 | 49 | ### Temperature (T) 50 | - **Range**: 0.1 - 2.0 51 | - **Effect**: Controls prediction randomness 52 | - **Recommendation**: 1.2-1.5 for better prediction quality 53 | 54 | ### Nucleus Sampling (top_p) 55 | - **Range**: 0.1 - 1.0 56 | - **Effect**: Controls prediction diversity 57 | - **Recommendation**: 0.95-1.0 to consider more possibilities 58 | 59 | ### Sample Count 60 | - **Range**: 1 - 5 61 | - **Effect**: Generate multiple prediction samples 62 | - **Recommendation**: 2-3 samples to improve quality 63 | 64 | ## 📊 Supported Data Formats 65 | 66 | ### Required Columns 67 | - `open`: Opening price 68 | - `high`: Highest price 69 | - `low`: Lowest price 70 | - `close`: Closing price 71 | 72 | ### Optional Columns 73 | - `volume`: Trading volume 74 | - `amount`: Trading amount (not used for prediction) 75 | - `timestamps`/`timestamp`/`date`: Timestamp 76 | 77 | ## 🤖 Model Support 78 | 79 | - **Kronos-mini**: 4.1M parameters, lightweight fast prediction 80 | - **Kronos-small**: 24.7M parameters, balanced performance and speed 81 | - **Kronos-base**: 102.3M parameters, high quality prediction 82 | 83 | ## 🖥️ GPU Acceleration Support 84 | 85 | - **CPU**: General computing, best compatibility 86 | - **CUDA**: NVIDIA GPU acceleration, best performance 87 | - **MPS**: Apple Silicon GPU acceleration, recommended for Mac users 88 | 89 | ## ⚠️ Notes 90 | 91 | - `amount` column is not used for prediction, only for display 92 | - Time window is fixed at 400+120=520 data points 93 | - Ensure data file contains sufficient historical data 94 | - First model loading may require download, please be patient 95 | 96 | ## 🔍 Comparison Analysis 97 | 98 | The system automatically provides comparison analysis between prediction results and actual data, including: 99 | - Price difference statistics 100 | - Error analysis 101 | - Prediction quality assessment 102 | 103 | ## 🛠️ Technical Architecture 104 | 105 | - **Backend**: Flask 
+ Python 106 | - **Frontend**: HTML + CSS + JavaScript 107 | - **Charts**: Plotly.js 108 | - **Data processing**: Pandas + NumPy 109 | - **Model**: Hugging Face Transformers 110 | 111 | ## 📝 Troubleshooting 112 | 113 | ### Common Issues 114 | 1. **Port occupied**: Modify port number in app.py 115 | 2. **Missing dependencies**: Run `pip install -r requirements.txt` 116 | 3. **Model loading failed**: Check network connection and model ID 117 | 4. **Data format error**: Ensure data column names and format are correct 118 | 119 | ### Log Viewing 120 | Detailed runtime information will be displayed in the console at startup, including model status and error messages. 121 | 122 | ## 📄 License 123 | 124 | This project follows the license terms of the original Kronos project. 125 | 126 | ## 🤝 Contributing 127 | 128 | Welcome to submit Issues and Pull Requests to improve this Web UI! 129 | 130 | ## 📞 Support 131 | 132 | If you have questions, please check: 133 | 1. Project documentation 134 | 2. GitHub Issues 135 | 3. Console error messages 136 | -------------------------------------------------------------------------------- /finetune/qlib_data_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import pandas as pd 5 | import qlib 6 | from qlib.config import REG_CN 7 | from qlib.data import D 8 | from qlib.data.dataset.loader import QlibDataLoader 9 | from tqdm import trange 10 | 11 | from config import Config 12 | 13 | 14 | class QlibDataPreprocessor: 15 | """ 16 | A class to handle the loading, processing, and splitting of Qlib financial data. 17 | """ 18 | 19 | def __init__(self): 20 | """Initializes the preprocessor with configuration and data fields.""" 21 | self.config = Config() 22 | self.data_fields = ['open', 'close', 'high', 'low', 'volume', 'vwap'] 23 | self.data = {} # A dictionary to store processed data for each symbol. 24 | 25 | def initialize_qlib(self): 26 | """Initializes the Qlib environment.""" 27 | print("Initializing Qlib...") 28 | qlib.init(provider_uri=self.config.qlib_data_path, region=REG_CN) 29 | 30 | def load_qlib_data(self): 31 | """ 32 | Loads raw data from Qlib, processes it symbol by symbol, and stores 33 | it in the `self.data` attribute. 34 | """ 35 | print("Loading and processing data from Qlib...") 36 | data_fields_qlib = ['$' + f for f in self.data_fields] 37 | cal: np.ndarray = D.calendar() 38 | 39 | # Determine the actual start and end times to load, including buffer for lookback and predict windows. 40 | start_index = cal.searchsorted(pd.Timestamp(self.config.dataset_begin_time)) 41 | end_index = cal.searchsorted(pd.Timestamp(self.config.dataset_end_time)) 42 | 43 | # Check if start_index lookbackw_window will cause negative index 44 | adjusted_start_index = max(start_index - self.config.lookback_window, 0) 45 | real_start_time = cal[adjusted_start_index] 46 | 47 | # Check if end_index exceeds the range of the array 48 | if end_index >= len(cal): 49 | end_index = len(cal) - 1 50 | elif cal[end_index] != pd.Timestamp(self.config.dataset_end_time): 51 | end_index -= 1 52 | 53 | # Check if end_index+predictw_window will exceed the range of the array 54 | adjusted_end_index = min(end_index + self.config.predict_window, len(cal) - 1) 55 | real_end_time = cal[adjusted_end_index] 56 | 57 | # Load data using Qlib's data loader. 
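        # (The loaded frame has one '$field' column per feature; the stack()/unstack(level=1)
        # below reshapes it so that each symbol becomes a column, and the per-symbol pivot in
        # the loop then recovers a (datetime x field) table for each stock.)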
58 | data_df = QlibDataLoader(config=data_fields_qlib).load( 59 | self.config.instrument, real_start_time, real_end_time 60 | ) 61 | data_df = data_df.stack().unstack(level=1) # Reshape for easier access. 62 | 63 | symbol_list = list(data_df.columns) 64 | for i in trange(len(symbol_list), desc="Processing Symbols"): 65 | symbol = symbol_list[i] 66 | symbol_df = data_df[symbol] 67 | 68 | # Pivot the table to have features as columns and datetime as index. 69 | symbol_df = symbol_df.reset_index().rename(columns={'level_1': 'field'}) 70 | symbol_df = pd.pivot(symbol_df, index='datetime', columns='field', values=symbol) 71 | symbol_df = symbol_df.rename(columns={f'${field}': field for field in self.data_fields}) 72 | 73 | # Calculate amount and select final features. 74 | symbol_df['vol'] = symbol_df['volume'] 75 | symbol_df['amt'] = (symbol_df['open'] + symbol_df['high'] + symbol_df['low'] + symbol_df['close']) / 4 * symbol_df['vol'] 76 | symbol_df = symbol_df[self.config.feature_list] 77 | 78 | # Filter out symbols with insufficient data. 79 | symbol_df = symbol_df.dropna() 80 | if len(symbol_df) < self.config.lookback_window + self.config.predict_window + 1: 81 | continue 82 | 83 | self.data[symbol] = symbol_df 84 | 85 | def prepare_dataset(self): 86 | """ 87 | Splits the loaded data into train, validation, and test sets and saves them to disk. 88 | """ 89 | print("Splitting data into train, validation, and test sets...") 90 | train_data, val_data, test_data = {}, {}, {} 91 | 92 | symbol_list = list(self.data.keys()) 93 | for i in trange(len(symbol_list), desc="Preparing Datasets"): 94 | symbol = symbol_list[i] 95 | symbol_df = self.data[symbol] 96 | 97 | # Define time ranges from config. 98 | train_start, train_end = self.config.train_time_range 99 | val_start, val_end = self.config.val_time_range 100 | test_start, test_end = self.config.test_time_range 101 | 102 | # Create boolean masks for each dataset split. 103 | train_mask = (symbol_df.index >= train_start) & (symbol_df.index <= train_end) 104 | val_mask = (symbol_df.index >= val_start) & (symbol_df.index <= val_end) 105 | test_mask = (symbol_df.index >= test_start) & (symbol_df.index <= test_end) 106 | 107 | # Apply masks to create the final datasets. 108 | train_data[symbol] = symbol_df[train_mask] 109 | val_data[symbol] = symbol_df[val_mask] 110 | test_data[symbol] = symbol_df[test_mask] 111 | 112 | # Save the datasets using pickle. 113 | os.makedirs(self.config.dataset_path, exist_ok=True) 114 | with open(f"{self.config.dataset_path}/train_data.pkl", 'wb') as f: 115 | pickle.dump(train_data, f) 116 | with open(f"{self.config.dataset_path}/val_data.pkl", 'wb') as f: 117 | pickle.dump(val_data, f) 118 | with open(f"{self.config.dataset_path}/test_data.pkl", 'wb') as f: 119 | pickle.dump(test_data, f) 120 | 121 | print("Datasets prepared and saved successfully.") 122 | 123 | 124 | if __name__ == '__main__': 125 | # This block allows the script to be run directly to perform data preprocessing. 
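    # Typical usage, assuming qlib_data_path and dataset_path in config.py have been
    # pointed at your own locations:
    #   cd finetune && python qlib_data_preprocess.py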
126 | preprocessor = QlibDataPreprocessor() 127 | preprocessor.initialize_qlib() 128 | preprocessor.load_qlib_data() 129 | preprocessor.prepare_dataset() 130 | 131 | -------------------------------------------------------------------------------- /finetune/dataset.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import random 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | from config import Config 7 | 8 | 9 | class QlibDataset(Dataset): 10 | """ 11 | A PyTorch Dataset for handling Qlib financial time series data. 12 | 13 | This dataset pre-computes all possible start indices for sliding windows 14 | and then randomly samples from them during training/validation. 15 | 16 | Args: 17 | data_type (str): The type of dataset to load, either 'train' or 'val'. 18 | 19 | Raises: 20 | ValueError: If `data_type` is not 'train' or 'val'. 21 | """ 22 | 23 | def __init__(self, data_type: str = 'train'): 24 | self.config = Config() 25 | if data_type not in ['train', 'val']: 26 | raise ValueError("data_type must be 'train' or 'val'") 27 | self.data_type = data_type 28 | 29 | # Use a dedicated random number generator for sampling to avoid 30 | # interfering with other random processes (e.g., in model initialization). 31 | self.py_rng = random.Random(self.config.seed) 32 | 33 | # Set paths and number of samples based on the data type. 34 | if data_type == 'train': 35 | self.data_path = f"{self.config.dataset_path}/train_data.pkl" 36 | self.n_samples = self.config.n_train_iter 37 | else: 38 | self.data_path = f"{self.config.dataset_path}/val_data.pkl" 39 | self.n_samples = self.config.n_val_iter 40 | 41 | with open(self.data_path, 'rb') as f: 42 | self.data = pickle.load(f) 43 | 44 | self.window = self.config.lookback_window + self.config.predict_window + 1 45 | 46 | self.symbols = list(self.data.keys()) 47 | self.feature_list = self.config.feature_list 48 | self.time_feature_list = self.config.time_feature_list 49 | 50 | # Pre-compute all possible (symbol, start_index) pairs. 51 | self.indices = [] 52 | print(f"[{data_type.upper()}] Pre-computing sample indices...") 53 | for symbol in self.symbols: 54 | df = self.data[symbol].reset_index() 55 | series_len = len(df) 56 | num_samples = series_len - self.window + 1 57 | 58 | if num_samples > 0: 59 | # Generate time features and store them directly in the dataframe. 60 | df['minute'] = df['datetime'].dt.minute 61 | df['hour'] = df['datetime'].dt.hour 62 | df['weekday'] = df['datetime'].dt.weekday 63 | df['day'] = df['datetime'].dt.day 64 | df['month'] = df['datetime'].dt.month 65 | # Keep only necessary columns to save memory. 66 | self.data[symbol] = df[self.feature_list + self.time_feature_list] 67 | 68 | # Add all valid starting indices for this symbol to the global list. 69 | for i in range(num_samples): 70 | self.indices.append((symbol, i)) 71 | 72 | # The effective dataset size is the minimum of the configured iterations 73 | # and the total number of available samples. 74 | self.n_samples = min(self.n_samples, len(self.indices)) 75 | print(f"[{data_type.upper()}] Found {len(self.indices)} possible samples. Using {self.n_samples} per epoch.") 76 | 77 | def set_epoch_seed(self, epoch: int): 78 | """ 79 | Sets a new seed for the random sampler for each epoch. This is crucial 80 | for reproducibility in distributed training. 81 | 82 | Args: 83 | epoch (int): The current epoch number. 
84 | """ 85 | epoch_seed = self.config.seed + epoch 86 | self.py_rng.seed(epoch_seed) 87 | 88 | def __len__(self) -> int: 89 | """Returns the number of samples per epoch.""" 90 | return self.n_samples 91 | 92 | def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: 93 | """ 94 | Retrieves a random sample from the dataset. 95 | 96 | Note: The `idx` argument is ignored. Instead, a random index is drawn 97 | from the pre-computed `self.indices` list using `self.py_rng`. This 98 | ensures random sampling over the entire dataset for each call. 99 | 100 | Args: 101 | idx (int): Ignored. 102 | 103 | Returns: 104 | tuple[torch.Tensor, torch.Tensor]: A tuple containing: 105 | - x_tensor (torch.Tensor): The normalized feature tensor. 106 | - x_stamp_tensor (torch.Tensor): The time feature tensor. 107 | """ 108 | # Select a random sample from the entire pool of indices. 109 | random_idx = self.py_rng.randint(0, len(self.indices) - 1) 110 | symbol, start_idx = self.indices[random_idx] 111 | 112 | # Extract the sliding window from the dataframe. 113 | df = self.data[symbol] 114 | end_idx = start_idx + self.window 115 | win_df = df.iloc[start_idx:end_idx] 116 | 117 | # Separate main features and time features. 118 | x = win_df[self.feature_list].values.astype(np.float32) 119 | x_stamp = win_df[self.time_feature_list].values.astype(np.float32) 120 | 121 | # Perform instance-level normalization. 122 | x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0) 123 | x = (x - x_mean) / (x_std + 1e-5) 124 | x = np.clip(x, -self.config.clip, self.config.clip) 125 | 126 | # Convert to PyTorch tensors. 127 | x_tensor = torch.from_numpy(x) 128 | x_stamp_tensor = torch.from_numpy(x_stamp) 129 | 130 | return x_tensor, x_stamp_tensor 131 | 132 | 133 | if __name__ == '__main__': 134 | # Example usage and verification. 135 | print("Creating training dataset instance...") 136 | train_dataset = QlibDataset(data_type='train') 137 | 138 | print(f"Dataset length: {len(train_dataset)}") 139 | 140 | if len(train_dataset) > 0: 141 | try_x, try_x_stamp = train_dataset[100] # Index 100 is ignored. 142 | print(f"Sample feature shape: {try_x.shape}") 143 | print(f"Sample time feature shape: {try_x_stamp.shape}") 144 | else: 145 | print("Dataset is empty.") 146 | -------------------------------------------------------------------------------- /finetune/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class Config: 4 | """ 5 | Configuration class for the entire project. 6 | """ 7 | 8 | def __init__(self): 9 | # ================================================================= 10 | # Data & Feature Parameters 11 | # ================================================================= 12 | # TODO: Update this path to your Qlib data directory. 13 | self.qlib_data_path = "~/.qlib/qlib_data/cn_data" 14 | self.instrument = 'csi300' 15 | 16 | # Overall time range for data loading from Qlib. 17 | self.dataset_begin_time = "2011-01-01" 18 | self.dataset_end_time = '2025-06-05' 19 | 20 | # Sliding window parameters for creating samples. 21 | self.lookback_window = 90 # Number of past time steps for input. 22 | self.predict_window = 10 # Number of future time steps for prediction. 23 | self.max_context = 512 # Maximum context length for the model. 24 | 25 | # Features to be used from the raw data. 26 | self.feature_list = ['open', 'high', 'low', 'close', 'vol', 'amt'] 27 | # Time-based features to be generated. 
28 | self.time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month'] 29 | 30 | # ================================================================= 31 | # Dataset Splitting & Paths 32 | # ================================================================= 33 | # Note: The validation/test set starts earlier than the training/validation set ends 34 | # to account for the `lookback_window`. 35 | self.train_time_range = ["2011-01-01", "2022-12-31"] 36 | self.val_time_range = ["2022-09-01", "2024-06-30"] 37 | self.test_time_range = ["2024-04-01", "2025-06-05"] 38 | self.backtest_time_range = ["2024-07-01", "2025-06-05"] 39 | 40 | # TODO: Directory to save the processed, pickled datasets. 41 | self.dataset_path = "./data/processed_datasets" 42 | 43 | # ================================================================= 44 | # Training Hyperparameters 45 | # ================================================================= 46 | self.clip = 5.0 # Clipping value for normalized data to prevent outliers. 47 | 48 | self.epochs = 30 49 | self.log_interval = 100 # Log training status every N batches. 50 | self.batch_size = 50 # Batch size per GPU. 51 | 52 | # Number of samples to draw for one "epoch" of training/validation. 53 | # This is useful for large datasets where a true epoch is too long. 54 | self.n_train_iter = 2000 * self.batch_size 55 | self.n_val_iter = 400 * self.batch_size 56 | 57 | # Learning rates for different model components. 58 | self.tokenizer_learning_rate = 2e-4 59 | self.predictor_learning_rate = 4e-5 60 | 61 | # Gradient accumulation to simulate a larger batch size. 62 | self.accumulation_steps = 1 63 | 64 | # AdamW optimizer parameters. 65 | self.adam_beta1 = 0.9 66 | self.adam_beta2 = 0.95 67 | self.adam_weight_decay = 0.1 68 | 69 | # Miscellaneous 70 | self.seed = 100 # Global random seed for reproducibility. 71 | 72 | # ================================================================= 73 | # Experiment Logging & Saving 74 | # ================================================================= 75 | self.use_comet = True # Set to False if you don't want to use Comet ML 76 | self.comet_config = { 77 | # It is highly recommended to load secrets from environment variables 78 | # for security purposes. Example: os.getenv("COMET_API_KEY") 79 | "api_key": "YOUR_COMET_API_KEY", 80 | "project_name": "Kronos-Finetune-Demo", 81 | "workspace": "your_comet_workspace" # TODO: Change to your Comet ML workspace name 82 | } 83 | self.comet_tag = 'finetune_demo' 84 | self.comet_name = 'finetune_demo' 85 | 86 | # Base directory for saving model checkpoints and results. 87 | # Using a general 'outputs' directory is a common practice. 88 | self.save_path = "./outputs/models" 89 | self.tokenizer_save_folder_name = 'finetune_tokenizer_demo' 90 | self.predictor_save_folder_name = 'finetune_predictor_demo' 91 | self.backtest_save_folder_name = 'finetune_backtest_demo' 92 | 93 | # Path for backtesting results. 94 | self.backtest_result_path = "./outputs/backtest_results" 95 | 96 | # ================================================================= 97 | # Model & Checkpoint Paths 98 | # ================================================================= 99 | # TODO: Update these paths to your pretrained model locations. 100 | # These can be local paths or Hugging Face Hub model identifiers. 101 | self.pretrained_tokenizer_path = "path/to/your/Kronos-Tokenizer-base" 102 | self.pretrained_predictor_path = "path/to/your/Kronos-small" 103 | 104 | # Paths to the fine-tuned models, derived from the save_path. 
105 | # These will be generated automatically during training. 106 | self.finetuned_tokenizer_path = f"{self.save_path}/{self.tokenizer_save_folder_name}/checkpoints/best_model" 107 | self.finetuned_predictor_path = f"{self.save_path}/{self.predictor_save_folder_name}/checkpoints/best_model" 108 | 109 | # ================================================================= 110 | # Backtesting Parameters 111 | # ================================================================= 112 | self.backtest_n_symbol_hold = 50 # Number of symbols to hold in the portfolio. 113 | self.backtest_n_symbol_drop = 5 # Number of symbols to drop from the pool. 114 | self.backtest_hold_thresh = 5 # Minimum holding period for a stock. 115 | self.inference_T = 0.6 116 | self.inference_top_p = 0.9 117 | self.inference_top_k = 0 118 | self.inference_sample_count = 5 119 | self.backtest_batch_size = 1000 120 | self.backtest_benchmark = self._set_benchmark(self.instrument) 121 | 122 | def _set_benchmark(self, instrument): 123 | dt_benchmark = { 124 | 'csi800': "SH000906", 125 | 'csi1000': "SH000852", 126 | 'csi300': "SH000300", 127 | } 128 | if instrument in dt_benchmark: 129 | return dt_benchmark[instrument] 130 | else: 131 | raise ValueError(f"Benchmark not defined for instrument: {instrument}") 132 | -------------------------------------------------------------------------------- /finetune/train_predictor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import time 5 | from time import gmtime, strftime 6 | import torch.distributed as dist 7 | import torch 8 | from torch.utils.data import DataLoader 9 | from torch.utils.data.distributed import DistributedSampler 10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | 12 | import comet_ml 13 | 14 | # Ensure project root is in path 15 | sys.path.append('../') 16 | from config import Config 17 | from dataset import QlibDataset 18 | from model.kronos import KronosTokenizer, Kronos 19 | # Import shared utilities 20 | from utils.training_utils import ( 21 | setup_ddp, 22 | cleanup_ddp, 23 | set_seed, 24 | get_model_size, 25 | format_time 26 | ) 27 | 28 | 29 | def create_dataloaders(config: dict, rank: int, world_size: int): 30 | """ 31 | Creates and returns distributed dataloaders for training and validation. 32 | 33 | Args: 34 | config (dict): A dictionary of configuration parameters. 35 | rank (int): The global rank of the current process. 36 | world_size (int): The total number of processes. 37 | 38 | Returns: 39 | tuple: (train_loader, val_loader, train_dataset, valid_dataset). 
40 | """ 41 | print(f"[Rank {rank}] Creating distributed dataloaders...") 42 | train_dataset = QlibDataset('train') 43 | valid_dataset = QlibDataset('val') 44 | print(f"[Rank {rank}] Train dataset size: {len(train_dataset)}, Validation dataset size: {len(valid_dataset)}") 45 | 46 | train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True) 47 | val_sampler = DistributedSampler(valid_dataset, num_replicas=world_size, rank=rank, shuffle=False) 48 | 49 | train_loader = DataLoader( 50 | train_dataset, batch_size=config['batch_size'], sampler=train_sampler, 51 | num_workers=config.get('num_workers', 2), pin_memory=True, drop_last=True 52 | ) 53 | val_loader = DataLoader( 54 | valid_dataset, batch_size=config['batch_size'], sampler=val_sampler, 55 | num_workers=config.get('num_workers', 2), pin_memory=True, drop_last=False 56 | ) 57 | return train_loader, val_loader, train_dataset, valid_dataset 58 | 59 | 60 | def train_model(model, tokenizer, device, config, save_dir, logger, rank, world_size): 61 | """ 62 | The main training and validation loop for the predictor. 63 | """ 64 | start_time = time.time() 65 | if rank == 0: 66 | effective_bs = config['batch_size'] * world_size 67 | print(f"Effective BATCHSIZE per GPU: {config['batch_size']}, Total: {effective_bs}") 68 | 69 | train_loader, val_loader, train_dataset, valid_dataset = create_dataloaders(config, rank, world_size) 70 | 71 | optimizer = torch.optim.AdamW( 72 | model.parameters(), 73 | lr=config['predictor_learning_rate'], 74 | betas=(config['adam_beta1'], config['adam_beta2']), 75 | weight_decay=config['adam_weight_decay'] 76 | ) 77 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 78 | optimizer, max_lr=config['predictor_learning_rate'], 79 | steps_per_epoch=len(train_loader), epochs=config['epochs'], 80 | pct_start=0.03, div_factor=10 81 | ) 82 | 83 | best_val_loss = float('inf') 84 | dt_result = {} 85 | batch_idx_global = 0 86 | 87 | for epoch_idx in range(config['epochs']): 88 | epoch_start_time = time.time() 89 | model.train() 90 | train_loader.sampler.set_epoch(epoch_idx) 91 | 92 | train_dataset.set_epoch_seed(epoch_idx * 10000 + rank) 93 | valid_dataset.set_epoch_seed(0) 94 | 95 | for i, (batch_x, batch_x_stamp) in enumerate(train_loader): 96 | batch_x = batch_x.squeeze(0).to(device, non_blocking=True) 97 | batch_x_stamp = batch_x_stamp.squeeze(0).to(device, non_blocking=True) 98 | 99 | # Tokenize input data on-the-fly 100 | with torch.no_grad(): 101 | token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True) 102 | 103 | # Prepare inputs and targets for the language model 104 | token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]] 105 | token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]] 106 | 107 | # Forward pass and loss calculation 108 | logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :]) 109 | loss, s1_loss, s2_loss = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1]) 110 | 111 | # Backward pass and optimization 112 | optimizer.zero_grad() 113 | loss.backward() 114 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0) 115 | optimizer.step() 116 | scheduler.step() 117 | 118 | # Logging (Master Process Only) 119 | if rank == 0 and (batch_idx_global + 1) % config['log_interval'] == 0: 120 | lr = optimizer.param_groups[0]['lr'] 121 | print( 122 | f"[Rank {rank}, Epoch {epoch_idx + 1}/{config['epochs']}, Step {i + 1}/{len(train_loader)}] " 123 | f"LR {lr:.6f}, Loss: {loss.item():.4f}" 124 | ) 125 | if rank == 0 and logger: 126 
| lr = optimizer.param_groups[0]['lr'] 127 | logger.log_metric('train_predictor_loss_batch', loss.item(), step=batch_idx_global) 128 | logger.log_metric('train_S1_loss_each_batch', s1_loss.item(), step=batch_idx_global) 129 | logger.log_metric('train_S2_loss_each_batch', s2_loss.item(), step=batch_idx_global) 130 | logger.log_metric('predictor_learning_rate', lr, step=batch_idx_global) 131 | 132 | batch_idx_global += 1 133 | 134 | # --- Validation Loop --- 135 | model.eval() 136 | tot_val_loss_sum_rank = 0.0 137 | val_batches_processed_rank = 0 138 | with torch.no_grad(): 139 | for batch_x, batch_x_stamp in val_loader: 140 | batch_x = batch_x.squeeze(0).to(device, non_blocking=True) 141 | batch_x_stamp = batch_x_stamp.squeeze(0).to(device, non_blocking=True) 142 | 143 | token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True) 144 | token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]] 145 | token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]] 146 | 147 | logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :]) 148 | val_loss, _, _ = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1]) 149 | 150 | tot_val_loss_sum_rank += val_loss.item() 151 | val_batches_processed_rank += 1 152 | 153 | # Reduce validation metrics 154 | val_loss_sum_tensor = torch.tensor(tot_val_loss_sum_rank, device=device) 155 | val_batches_tensor = torch.tensor(val_batches_processed_rank, device=device) 156 | dist.all_reduce(val_loss_sum_tensor, op=dist.ReduceOp.SUM) 157 | dist.all_reduce(val_batches_tensor, op=dist.ReduceOp.SUM) 158 | 159 | avg_val_loss = val_loss_sum_tensor.item() / val_batches_tensor.item() if val_batches_tensor.item() > 0 else 0 160 | 161 | # --- End of Epoch Summary & Checkpointing (Master Process Only) --- 162 | if rank == 0: 163 | print(f"\n--- Epoch {epoch_idx + 1}/{config['epochs']} Summary ---") 164 | print(f"Validation Loss: {avg_val_loss:.4f}") 165 | print(f"Time This Epoch: {format_time(time.time() - epoch_start_time)}") 166 | print(f"Total Time Elapsed: {format_time(time.time() - start_time)}\n") 167 | if logger: 168 | logger.log_metric('val_predictor_loss_epoch', avg_val_loss, epoch=epoch_idx) 169 | 170 | if avg_val_loss < best_val_loss: 171 | best_val_loss = avg_val_loss 172 | save_path = f"{save_dir}/checkpoints/best_model" 173 | model.module.save_pretrained(save_path) 174 | print(f"Best model saved to {save_path} (Val Loss: {best_val_loss:.4f})") 175 | 176 | dist.barrier() 177 | 178 | dt_result['best_val_loss'] = best_val_loss 179 | return dt_result 180 | 181 | 182 | def main(config: dict): 183 | """Main function to orchestrate the DDP training process.""" 184 | rank, world_size, local_rank = setup_ddp() 185 | device = torch.device(f"cuda:{local_rank}") 186 | set_seed(config['seed'], rank) 187 | 188 | save_dir = os.path.join(config['save_path'], config['predictor_save_folder_name']) 189 | 190 | # Logger and summary setup (master process only) 191 | comet_logger, master_summary = None, {} 192 | if rank == 0: 193 | os.makedirs(os.path.join(save_dir, 'checkpoints'), exist_ok=True) 194 | master_summary = { 195 | 'start_time': strftime("%Y-%m-%dT%H-%M-%S", gmtime()), 196 | 'save_directory': save_dir, 197 | 'world_size': world_size, 198 | } 199 | if config['use_comet']: 200 | comet_logger = comet_ml.Experiment( 201 | api_key=config['comet_config']['api_key'], 202 | project_name=config['comet_config']['project_name'], 203 | workspace=config['comet_config']['workspace'], 204 | ) 205 | comet_logger.add_tag(config['comet_tag']) 206 | 
comet_logger.set_name(config['comet_name']) 207 | comet_logger.log_parameters(config) 208 | print("Comet Logger Initialized.") 209 | 210 | dist.barrier() 211 | 212 | # Model Initialization 213 | tokenizer = KronosTokenizer.from_pretrained(config['finetuned_tokenizer_path']) 214 | tokenizer.eval().to(device) 215 | 216 | model = Kronos.from_pretrained(config['pretrained_predictor_path']) 217 | model.to(device) 218 | model = DDP(model, device_ids=[local_rank], find_unused_parameters=False) 219 | 220 | if rank == 0: 221 | print(f"Predictor Model Size: {get_model_size(model.module)}") 222 | 223 | # Start Training 224 | dt_result = train_model( 225 | model, tokenizer, device, config, save_dir, comet_logger, rank, world_size 226 | ) 227 | 228 | if rank == 0: 229 | master_summary['final_result'] = dt_result 230 | with open(os.path.join(save_dir, 'summary.json'), 'w') as f: 231 | json.dump(master_summary, f, indent=4) 232 | print('Training finished. Summary file saved.') 233 | if comet_logger: comet_logger.end() 234 | 235 | cleanup_ddp() 236 | 237 | 238 | if __name__ == '__main__': 239 | # Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_predictor.py 240 | if "WORLD_SIZE" not in os.environ: 241 | raise RuntimeError("This script must be launched with `torchrun`.") 242 | 243 | config_instance = Config() 244 | main(config_instance.__dict__) 245 | -------------------------------------------------------------------------------- /finetune/train_tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import time 5 | from time import gmtime, strftime 6 | import argparse 7 | import datetime 8 | import torch.distributed as dist 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader 12 | from torch.utils.data.distributed import DistributedSampler 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | 15 | import comet_ml 16 | 17 | # Ensure project root is in path 18 | sys.path.append("../") 19 | from config import Config 20 | from dataset import QlibDataset 21 | from model.kronos import KronosTokenizer 22 | # Import shared utilities 23 | from utils.training_utils import ( 24 | setup_ddp, 25 | cleanup_ddp, 26 | set_seed, 27 | get_model_size, 28 | format_time, 29 | ) 30 | 31 | 32 | def create_dataloaders(config: dict, rank: int, world_size: int): 33 | """ 34 | Creates and returns distributed dataloaders for training and validation. 35 | 36 | Args: 37 | config (dict): A dictionary of configuration parameters. 38 | rank (int): The global rank of the current process. 39 | world_size (int): The total number of processes. 40 | 41 | Returns: 42 | tuple: A tuple containing (train_loader, val_loader, train_dataset, valid_dataset). 
43 | """ 44 | print(f"[Rank {rank}] Creating distributed dataloaders...") 45 | train_dataset = QlibDataset('train') 46 | valid_dataset = QlibDataset('val') 47 | print(f"[Rank {rank}] Train dataset size: {len(train_dataset)}, Validation dataset size: {len(valid_dataset)}") 48 | 49 | train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True) 50 | val_sampler = DistributedSampler(valid_dataset, num_replicas=world_size, rank=rank, shuffle=False) 51 | 52 | train_loader = DataLoader( 53 | train_dataset, 54 | batch_size=config['batch_size'], 55 | sampler=train_sampler, 56 | shuffle=False, # Shuffle is handled by the sampler 57 | num_workers=config.get('num_workers', 2), 58 | pin_memory=True, 59 | drop_last=True 60 | ) 61 | val_loader = DataLoader( 62 | valid_dataset, 63 | batch_size=config['batch_size'], 64 | sampler=val_sampler, 65 | shuffle=False, 66 | num_workers=config.get('num_workers', 2), 67 | pin_memory=True, 68 | drop_last=False 69 | ) 70 | print(f"[Rank {rank}] Dataloaders created. Train steps/epoch: {len(train_loader)}, Val steps: {len(val_loader)}") 71 | return train_loader, val_loader, train_dataset, valid_dataset 72 | 73 | 74 | def train_model(model, device, config, save_dir, logger, rank, world_size): 75 | """ 76 | The main training and validation loop for the tokenizer. 77 | 78 | Args: 79 | model (DDP): The DDP-wrapped model to train. 80 | device (torch.device): The device for the current process. 81 | config (dict): Configuration dictionary. 82 | save_dir (str): Directory to save checkpoints. 83 | logger (comet_ml.Experiment): Comet logger instance. 84 | rank (int): Global rank of the process. 85 | world_size (int): Total number of processes. 86 | 87 | Returns: 88 | tuple: A tuple containing the trained model and a dictionary of results. 
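    Note:
        Each optimizer step minimizes loss = (MSE(z_pre, x) + MSE(z, x) + bsq_loss) / 2,
        accumulated over `accumulation_steps` micro-batches with each backward pass
        scaled by 1 / accumulation_steps (see the gradient accumulation loop below).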
89 | """ 90 | start_time = time.time() 91 | if rank == 0: 92 | effective_bs = config['batch_size'] * world_size * config['accumulation_steps'] 93 | print(f"[Rank {rank}] BATCHSIZE (per GPU): {config['batch_size']}") 94 | print(f"[Rank {rank}] Effective total batch size: {effective_bs}") 95 | 96 | train_loader, val_loader, train_dataset, valid_dataset = create_dataloaders(config, rank, world_size) 97 | 98 | optimizer = torch.optim.AdamW( 99 | model.parameters(), 100 | lr=config['tokenizer_learning_rate'], 101 | weight_decay=config['adam_weight_decay'] 102 | ) 103 | 104 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 105 | optimizer=optimizer, 106 | max_lr=config['tokenizer_learning_rate'], 107 | steps_per_epoch=len(train_loader), 108 | epochs=config['epochs'], 109 | pct_start=0.03, 110 | div_factor=10 111 | ) 112 | 113 | best_val_loss = float('inf') 114 | dt_result = {} 115 | batch_idx_global_train = 0 116 | 117 | for epoch_idx in range(config['epochs']): 118 | epoch_start_time = time.time() 119 | model.train() 120 | train_loader.sampler.set_epoch(epoch_idx) 121 | 122 | # Set dataset seeds for reproducible sampling 123 | train_dataset.set_epoch_seed(epoch_idx * 10000 + rank) 124 | valid_dataset.set_epoch_seed(0) # Keep validation sampling consistent 125 | 126 | for i, (ori_batch_x, _) in enumerate(train_loader): 127 | ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True) 128 | 129 | # --- Gradient Accumulation Loop --- 130 | current_batch_total_loss = 0.0 131 | for j in range(config['accumulation_steps']): 132 | start_idx = j * (ori_batch_x.shape[0] // config['accumulation_steps']) 133 | end_idx = (j + 1) * (ori_batch_x.shape[0] // config['accumulation_steps']) 134 | batch_x = ori_batch_x[start_idx:end_idx] 135 | 136 | # Forward pass 137 | zs, bsq_loss, _, _ = model(batch_x) 138 | z_pre, z = zs 139 | 140 | # Loss calculation 141 | recon_loss_pre = F.mse_loss(z_pre, batch_x) 142 | recon_loss_all = F.mse_loss(z, batch_x) 143 | recon_loss = recon_loss_pre + recon_loss_all 144 | loss = (recon_loss + bsq_loss) / 2 # Assuming w_1=w_2=1 145 | 146 | loss_scaled = loss / config['accumulation_steps'] 147 | current_batch_total_loss += loss.item() 148 | loss_scaled.backward() 149 | 150 | # --- Optimizer Step after Accumulation --- 151 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0) 152 | optimizer.step() 153 | scheduler.step() 154 | optimizer.zero_grad() 155 | 156 | # --- Logging (Master Process Only) --- 157 | if rank == 0 and (batch_idx_global_train + 1) % config['log_interval'] == 0: 158 | avg_loss = current_batch_total_loss / config['accumulation_steps'] 159 | print( 160 | f"[Rank {rank}, Epoch {epoch_idx + 1}/{config['epochs']}, Step {i + 1}/{len(train_loader)}] " 161 | f"LR {optimizer.param_groups[0]['lr']:.6f}, Loss: {avg_loss:.4f}" 162 | ) 163 | if rank == 0 and logger: 164 | avg_loss = current_batch_total_loss / config['accumulation_steps'] 165 | logger.log_metric('train_tokenizer_loss_batch', avg_loss, step=batch_idx_global_train) 166 | logger.log_metric(f'train_vqvae_vq_loss_each_batch', bsq_loss.item(), step=batch_idx_global_train) 167 | logger.log_metric(f'train_recon_loss_pre_each_batch', recon_loss_pre.item(), step=batch_idx_global_train) 168 | logger.log_metric(f'train_recon_loss_each_batch', recon_loss_all.item(), step=batch_idx_global_train) 169 | logger.log_metric('tokenizer_learning_rate', optimizer.param_groups[0]["lr"], step=batch_idx_global_train) 170 | 171 | batch_idx_global_train += 1 172 | 173 | # --- Validation Loop --- 174 | model.eval() 
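    # The validation block below accumulates a sample-weighted MSE sum and a sample
    # count on each rank, all-reduces both, and divides to get a global average that
    # is identical on every process. A minimal sketch of the same reduction, assuming
    # placeholder names `local_loss_sum` and `local_count` (illustrative only):
    #   stats = torch.tensor([local_loss_sum, local_count], device=device)
    #   dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    #   avg_val_loss = (stats[0] / stats[1]).item()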
175 | tot_val_loss_sum_rank = 0.0 176 | val_sample_count_rank = 0 177 | with torch.no_grad(): 178 | for ori_batch_x, _ in val_loader: 179 | ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True) 180 | zs, _, _, _ = model(ori_batch_x) 181 | _, z = zs 182 | val_loss_item = F.mse_loss(z, ori_batch_x) 183 | 184 | tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0) 185 | val_sample_count_rank += ori_batch_x.size(0) 186 | 187 | # Reduce validation losses from all processes 188 | val_loss_sum_tensor = torch.tensor(tot_val_loss_sum_rank, device=device) 189 | val_count_tensor = torch.tensor(val_sample_count_rank, device=device) 190 | dist.all_reduce(val_loss_sum_tensor, op=dist.ReduceOp.SUM) 191 | dist.all_reduce(val_count_tensor, op=dist.ReduceOp.SUM) 192 | 193 | avg_val_loss = val_loss_sum_tensor.item() / val_count_tensor.item() if val_count_tensor.item() > 0 else 0 194 | 195 | # --- End of Epoch Summary & Checkpointing (Master Process Only) --- 196 | if rank == 0: 197 | print(f"\n--- Epoch {epoch_idx + 1}/{config['epochs']} Summary ---") 198 | print(f"Validation Loss: {avg_val_loss:.4f}") 199 | print(f"Time This Epoch: {format_time(time.time() - epoch_start_time)}") 200 | print(f"Total Time Elapsed: {format_time(time.time() - start_time)}\n") 201 | if logger: 202 | logger.log_metric('val_tokenizer_loss_epoch', avg_val_loss, epoch=epoch_idx) 203 | 204 | if avg_val_loss < best_val_loss: 205 | best_val_loss = avg_val_loss 206 | save_path = f"{save_dir}/checkpoints/best_model" 207 | model.module.save_pretrained(save_path) 208 | print(f"Best model saved to {save_path} (Val Loss: {best_val_loss:.4f})") 209 | if logger: 210 | logger.log_model("best_model", save_path) 211 | 212 | dist.barrier() # Ensure all processes finish the epoch before starting the next one. 213 | 214 | dt_result['best_val_loss'] = best_val_loss 215 | return model, dt_result 216 | 217 | 218 | def main(config: dict): 219 | """ 220 | Main function to orchestrate the DDP training process. 
221 | """ 222 | rank, world_size, local_rank = setup_ddp() 223 | device = torch.device(f"cuda:{local_rank}") 224 | set_seed(config['seed'], rank) 225 | 226 | save_dir = os.path.join(config['save_path'], config['tokenizer_save_folder_name']) 227 | 228 | # Logger and summary setup (master process only) 229 | comet_logger, master_summary = None, {} 230 | if rank == 0: 231 | os.makedirs(os.path.join(save_dir, 'checkpoints'), exist_ok=True) 232 | master_summary = { 233 | 'start_time': strftime("%Y-%m-%dT%H-%M-%S", gmtime()), 234 | 'save_directory': save_dir, 235 | 'world_size': world_size, 236 | } 237 | if config['use_comet']: 238 | comet_logger = comet_ml.Experiment( 239 | api_key=config['comet_config']['api_key'], 240 | project_name=config['comet_config']['project_name'], 241 | workspace=config['comet_config']['workspace'], 242 | ) 243 | comet_logger.add_tag(config['comet_tag']) 244 | comet_logger.set_name(config['comet_name']) 245 | comet_logger.log_parameters(config) 246 | print("Comet Logger Initialized.") 247 | 248 | dist.barrier() # Ensure save directory is created before proceeding 249 | 250 | # Model Initialization 251 | model = KronosTokenizer.from_pretrained(config['pretrained_tokenizer_path']) 252 | model.to(device) 253 | model = DDP(model, device_ids=[local_rank], find_unused_parameters=False) 254 | 255 | if rank == 0: 256 | print(f"Model Size: {get_model_size(model.module)}") 257 | 258 | # Start Training 259 | _, dt_result = train_model( 260 | model, device, config, save_dir, comet_logger, rank, world_size 261 | ) 262 | 263 | # Finalize and save summary (master process only) 264 | if rank == 0: 265 | master_summary['final_result'] = dt_result 266 | with open(os.path.join(save_dir, 'summary.json'), 'w') as f: 267 | json.dump(master_summary, f, indent=4) 268 | print('Training finished. 
Summary file saved.') 269 | if comet_logger: 270 | comet_logger.end() 271 | 272 | cleanup_ddp() 273 | 274 | 275 | if __name__ == '__main__': 276 | # Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_tokenizer.py 277 | if "WORLD_SIZE" not in os.environ: 278 | raise RuntimeError("This script must be launched with `torchrun`.") 279 | 280 | config_instance = Config() 281 | main(config_instance.__dict__) 282 | -------------------------------------------------------------------------------- /finetune_csv/config_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from typing import Dict, Any 4 | 5 | 6 | class ConfigLoader: 7 | 8 | def __init__(self, config_path: str): 9 | 10 | self.config_path = config_path 11 | self.config = self._load_config() 12 | 13 | def _load_config(self) -> Dict[str, Any]: 14 | 15 | if not os.path.exists(self.config_path): 16 | raise FileNotFoundError(f"config file not found: {self.config_path}") 17 | 18 | with open(self.config_path, 'r', encoding='utf-8') as f: 19 | config = yaml.safe_load(f) 20 | 21 | config = self._resolve_dynamic_paths(config) 22 | 23 | return config 24 | 25 | def _resolve_dynamic_paths(self, config: Dict[str, Any]) -> Dict[str, Any]: 26 | 27 | exp_name = config.get('model_paths', {}).get('exp_name', '') 28 | if not exp_name: 29 | return config 30 | 31 | base_path = config.get('model_paths', {}).get('base_path', '') 32 | path_templates = { 33 | 'base_save_path': f"{base_path}/{exp_name}", 34 | 'finetuned_tokenizer': f"{base_path}/{exp_name}/tokenizer/best_model" 35 | } 36 | 37 | if 'model_paths' in config: 38 | for key, template in path_templates.items(): 39 | if key in config['model_paths']: 40 | # only use template when the original value is empty string 41 | current_value = config['model_paths'][key] 42 | if current_value == "" or current_value is None: 43 | config['model_paths'][key] = template 44 | else: 45 | # if the original value is not empty, use template to replace the {exp_name} placeholder 46 | if isinstance(current_value, str) and '{exp_name}' in current_value: 47 | config['model_paths'][key] = current_value.format(exp_name=exp_name) 48 | 49 | return config 50 | 51 | def get(self, key: str, default=None): 52 | 53 | keys = key.split('.') 54 | value = self.config 55 | 56 | try: 57 | for k in keys: 58 | value = value[k] 59 | return value 60 | except (KeyError, TypeError): 61 | return default 62 | 63 | def get_data_config(self) -> Dict[str, Any]: 64 | return self.config.get('data', {}) 65 | 66 | def get_training_config(self) -> Dict[str, Any]: 67 | return self.config.get('training', {}) 68 | 69 | def get_model_paths(self) -> Dict[str, str]: 70 | return self.config.get('model_paths', {}) 71 | 72 | def get_experiment_config(self) -> Dict[str, Any]: 73 | return self.config.get('experiment', {}) 74 | 75 | def get_device_config(self) -> Dict[str, Any]: 76 | return self.config.get('device', {}) 77 | 78 | def get_distributed_config(self) -> Dict[str, Any]: 79 | return self.config.get('distributed', {}) 80 | 81 | def update_config(self, updates: Dict[str, Any]): 82 | 83 | def update_nested_dict(d, u): 84 | for k, v in u.items(): 85 | if isinstance(v, dict): 86 | d[k] = update_nested_dict(d.get(k, {}), v) 87 | else: 88 | d[k] = v 89 | return d 90 | 91 | self.config = update_nested_dict(self.config, updates) 92 | 93 | def save_config(self, save_path: str = None): 94 | 95 | if save_path is None: 96 | save_path = self.config_path 97 | 98 | with open(save_path, 'w', 
encoding='utf-8') as f: 99 | yaml.dump(self.config, f, default_flow_style=False, allow_unicode=True, indent=2) 100 | 101 | def print_config(self): 102 | print("=" * 50) 103 | print("Current configuration:") 104 | print("=" * 50) 105 | print(yaml.dump(self.config, default_flow_style=False, allow_unicode=True, indent=2)) # yaml.dump returns a string here; print it so the config is actually shown 106 | print("=" * 50) 107 | 108 | 109 | class CustomFinetuneConfig: 110 | 111 | def __init__(self, config_path: str = None): 112 | 113 | if config_path is None: 114 | config_path = os.path.join(os.path.dirname(__file__), 'config.yaml') 115 | 116 | self.loader = ConfigLoader(config_path) 117 | self._load_all_configs() 118 | 119 | def _load_all_configs(self): 120 | 121 | data_config = self.loader.get_data_config() 122 | self.data_path = data_config.get('data_path') 123 | self.lookback_window = data_config.get('lookback_window', 512) 124 | self.predict_window = data_config.get('predict_window', 48) 125 | self.max_context = data_config.get('max_context', 512) 126 | self.clip = data_config.get('clip', 5.0) 127 | self.train_ratio = data_config.get('train_ratio', 0.9) 128 | self.val_ratio = data_config.get('val_ratio', 0.1) 129 | self.test_ratio = data_config.get('test_ratio', 0.0) 130 | 131 | # training configuration 132 | training_config = self.loader.get_training_config() 133 | # support training epochs of tokenizer and basemodel separately 134 | self.tokenizer_epochs = training_config.get('tokenizer_epochs', 30) 135 | self.basemodel_epochs = training_config.get('basemodel_epochs', 30) 136 | 137 | if 'epochs' in training_config and 'tokenizer_epochs' not in training_config: 138 | self.tokenizer_epochs = training_config.get('epochs', 30) 139 | if 'epochs' in training_config and 'basemodel_epochs' not in training_config: 140 | self.basemodel_epochs = training_config.get('epochs', 30) 141 | 142 | self.batch_size = training_config.get('batch_size', 160) 143 | self.log_interval = training_config.get('log_interval', 50) 144 | self.num_workers = training_config.get('num_workers', 6) 145 | self.seed = training_config.get('seed', 100) 146 | self.tokenizer_learning_rate = training_config.get('tokenizer_learning_rate', 2e-4) 147 | self.predictor_learning_rate = training_config.get('predictor_learning_rate', 4e-5) 148 | self.adam_beta1 = training_config.get('adam_beta1', 0.9) 149 | self.adam_beta2 = training_config.get('adam_beta2', 0.95) 150 | self.adam_weight_decay = training_config.get('adam_weight_decay', 0.1) 151 | self.accumulation_steps = training_config.get('accumulation_steps', 1) 152 | 153 | model_paths = self.loader.get_model_paths() 154 | self.exp_name = model_paths.get('exp_name', 'default_experiment') 155 | self.pretrained_tokenizer_path = model_paths.get('pretrained_tokenizer') 156 | self.pretrained_predictor_path = model_paths.get('pretrained_predictor') 157 | self.base_save_path = model_paths.get('base_save_path') 158 | self.tokenizer_save_name = model_paths.get('tokenizer_save_name', 'tokenizer') 159 | self.basemodel_save_name = model_paths.get('basemodel_save_name', 'basemodel') 160 | self.finetuned_tokenizer_path = model_paths.get('finetuned_tokenizer') 161 | 162 | experiment_config = self.loader.get_experiment_config() 163 | self.experiment_name = experiment_config.get('name', 'kronos_custom_finetune') 164 | self.experiment_description = experiment_config.get('description', '') 165 | self.use_comet = experiment_config.get('use_comet', False) 166 | self.train_tokenizer = experiment_config.get('train_tokenizer', True) 167 | self.train_basemodel =
experiment_config.get('train_basemodel', True) 168 | self.skip_existing = experiment_config.get('skip_existing', False) 169 | 170 | unified_pretrained = experiment_config.get('pre_trained', None) 171 | self.pre_trained_tokenizer = experiment_config.get('pre_trained_tokenizer', unified_pretrained if unified_pretrained is not None else True) 172 | self.pre_trained_predictor = experiment_config.get('pre_trained_predictor', unified_pretrained if unified_pretrained is not None else True) 173 | 174 | device_config = self.loader.get_device_config() 175 | self.use_cuda = device_config.get('use_cuda', True) 176 | self.device_id = device_config.get('device_id', 0) 177 | 178 | distributed_config = self.loader.get_distributed_config() 179 | self.use_ddp = distributed_config.get('use_ddp', False) 180 | self.ddp_backend = distributed_config.get('backend', 'nccl') 181 | 182 | self._compute_full_paths() 183 | 184 | def _compute_full_paths(self): 185 | 186 | self.tokenizer_save_path = os.path.join(self.base_save_path, self.tokenizer_save_name) 187 | self.tokenizer_best_model_path = os.path.join(self.tokenizer_save_path, 'best_model') 188 | 189 | self.basemodel_save_path = os.path.join(self.base_save_path, self.basemodel_save_name) 190 | self.basemodel_best_model_path = os.path.join(self.basemodel_save_path, 'best_model') 191 | 192 | def get_tokenizer_config(self): 193 | 194 | return { 195 | 'data_path': self.data_path, 196 | 'lookback_window': self.lookback_window, 197 | 'predict_window': self.predict_window, 198 | 'max_context': self.max_context, 199 | 'clip': self.clip, 200 | 'train_ratio': self.train_ratio, 201 | 'val_ratio': self.val_ratio, 202 | 'test_ratio': self.test_ratio, 203 | 'epochs': self.tokenizer_epochs, 204 | 'batch_size': self.batch_size, 205 | 'log_interval': self.log_interval, 206 | 'num_workers': self.num_workers, 207 | 'seed': self.seed, 208 | 'learning_rate': self.tokenizer_learning_rate, 209 | 'adam_beta1': self.adam_beta1, 210 | 'adam_beta2': self.adam_beta2, 211 | 'adam_weight_decay': self.adam_weight_decay, 212 | 'accumulation_steps': self.accumulation_steps, 213 | 'pretrained_model_path': self.pretrained_tokenizer_path, 214 | 'save_path': self.tokenizer_save_path, 215 | 'use_comet': self.use_comet 216 | } 217 | 218 | def get_basemodel_config(self): 219 | 220 | return { 221 | 'data_path': self.data_path, 222 | 'lookback_window': self.lookback_window, 223 | 'predict_window': self.predict_window, 224 | 'max_context': self.max_context, 225 | 'clip': self.clip, 226 | 'train_ratio': self.train_ratio, 227 | 'val_ratio': self.val_ratio, 228 | 'test_ratio': self.test_ratio, 229 | 'epochs': self.basemodel_epochs, 230 | 'batch_size': self.batch_size, 231 | 'log_interval': self.log_interval, 232 | 'num_workers': self.num_workers, 233 | 'seed': self.seed, 234 | 'predictor_learning_rate': self.predictor_learning_rate, 235 | 'tokenizer_learning_rate': self.tokenizer_learning_rate, 236 | 'adam_beta1': self.adam_beta1, 237 | 'adam_beta2': self.adam_beta2, 238 | 'adam_weight_decay': self.adam_weight_decay, 239 | 'pretrained_tokenizer_path': self.finetuned_tokenizer_path, 240 | 'pretrained_predictor_path': self.pretrained_predictor_path, 241 | 'save_path': self.basemodel_save_path, 242 | 'use_comet': self.use_comet 243 | } 244 | 245 | def print_config_summary(self): 246 | 247 | print("=" * 60) 248 | print("Kronos finetuning configuration summary") 249 | print("=" * 60) 250 | print(f"Experiment name: {self.exp_name}") 251 | print(f"Data path: {self.data_path}") 252 | print(f"Lookback window: 
{self.lookback_window}") 253 | print(f"Predict window: {self.predict_window}") 254 | print(f"Tokenizer training epochs: {self.tokenizer_epochs}") 255 | print(f"Basemodel training epochs: {self.basemodel_epochs}") 256 | print(f"Batch size: {self.batch_size}") 257 | print(f"Tokenizer learning rate: {self.tokenizer_learning_rate}") 258 | print(f"Predictor learning rate: {self.predictor_learning_rate}") 259 | print(f"Train tokenizer: {self.train_tokenizer}") 260 | print(f"Train basemodel: {self.train_basemodel}") 261 | print(f"Skip existing: {self.skip_existing}") 262 | print(f"Use pre-trained tokenizer: {self.pre_trained_tokenizer}") 263 | print(f"Use pre-trained predictor: {self.pre_trained_predictor}") 264 | print(f"Base save path: {self.base_save_path}") 265 | print(f"Tokenizer save path: {self.tokenizer_save_path}") 266 | print(f"Basemodel save path: {self.basemodel_save_path}") 267 | print("=" * 60) 268 | -------------------------------------------------------------------------------- /finetune_csv/finetune_tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import time 5 | import random 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.utils.data import DataLoader 10 | from torch.utils.data.distributed import DistributedSampler 11 | from time import gmtime, strftime 12 | import datetime 13 | import logging 14 | from logging.handlers import RotatingFileHandler 15 | import torch.distributed as dist 16 | from torch.nn.parallel import DistributedDataParallel as DDP 17 | 18 | sys.path.append("../") 19 | from model import KronosTokenizer 20 | from finetune_base_model import CustomKlineDataset 21 | from config_loader import CustomFinetuneConfig 22 | 23 | 24 | def set_seed(seed: int, rank: int = 0): 25 | actual_seed = seed 26 | random.seed(actual_seed) 27 | np.random.seed(actual_seed) 28 | torch.manual_seed(actual_seed) 29 | if torch.cuda.is_available(): 30 | torch.cuda.manual_seed_all(actual_seed) 31 | torch.backends.cudnn.deterministic = True 32 | torch.backends.cudnn.benchmark = False 33 | 34 | 35 | def get_model_size(model: torch.nn.Module) -> str: 36 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 37 | if total_params >= 1e9: 38 | return f"{total_params / 1e9:.1f}B" 39 | elif total_params >= 1e6: 40 | return f"{total_params / 1e6:.1f}M" 41 | else: 42 | return f"{total_params / 1e3:.1f}K" 43 | 44 | 45 | def format_time(seconds: float) -> str: 46 | return str(datetime.timedelta(seconds=int(seconds))) 47 | 48 | 49 | def setup_logging(exp_name: str, log_dir: str, rank: int = 0) -> logging.Logger: 50 | os.makedirs(log_dir, exist_ok=True) 51 | 52 | logger = logging.getLogger(f"tokenizer_training_rank_{rank}") 53 | logger.setLevel(logging.INFO) 54 | 55 | if logger.handlers: 56 | return logger 57 | 58 | log_file = os.path.join(log_dir, f"tokenizer_training_rank_{rank}.log") 59 | file_handler = RotatingFileHandler( 60 | log_file, 61 | maxBytes=10*1024*1024, 62 | backupCount=5, 63 | encoding='utf-8' 64 | ) 65 | file_handler.setLevel(logging.INFO) 66 | 67 | console_handler = None 68 | if rank == 0: 69 | console_handler = logging.StreamHandler() 70 | console_handler.setLevel(logging.INFO) 71 | 72 | formatter = logging.Formatter( 73 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s', 74 | datefmt='%Y-%m-%d %H:%M:%S' 75 | ) 76 | file_handler.setFormatter(formatter) 77 | if console_handler is not None: 78 | 
console_handler.setFormatter(formatter) 79 | 80 | logger.addHandler(file_handler) 81 | if console_handler is not None: 82 | logger.addHandler(console_handler) 83 | 84 | logger.info(f"=== Tokenizer Training Started ===") 85 | logger.info(f"Experiment Name: {exp_name}") 86 | logger.info(f"Log Directory: {log_dir}") 87 | logger.info(f"Rank: {rank}") 88 | logger.info(f"Timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") 89 | 90 | return logger 91 | 92 | 93 | def create_dataloaders(config): 94 | if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0: 95 | print("Creating tokenizer training data loaders...") 96 | 97 | train_dataset = CustomKlineDataset( 98 | data_path=config.data_path, 99 | data_type="train", 100 | lookback_window=config.lookback_window, 101 | predict_window=config.predict_window, 102 | clip=config.clip, 103 | seed=config.seed, 104 | train_ratio=config.train_ratio, 105 | val_ratio=config.val_ratio, 106 | test_ratio=config.test_ratio 107 | ) 108 | 109 | val_dataset = CustomKlineDataset( 110 | data_path=config.data_path, 111 | data_type="val", 112 | lookback_window=config.lookback_window, 113 | predict_window=config.predict_window, 114 | clip=config.clip, 115 | seed=config.seed + 1, 116 | train_ratio=config.train_ratio, 117 | val_ratio=config.val_ratio, 118 | test_ratio=config.test_ratio 119 | ) 120 | 121 | use_ddp = dist.is_available() and dist.is_initialized() 122 | train_sampler = DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) if use_ddp else None 123 | val_sampler = DistributedSampler(val_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False, drop_last=False) if use_ddp else None 124 | 125 | train_loader = DataLoader( 126 | train_dataset, 127 | batch_size=config.batch_size, 128 | shuffle=(train_sampler is None), 129 | num_workers=config.num_workers, 130 | pin_memory=True, 131 | drop_last=True, 132 | sampler=train_sampler 133 | ) 134 | 135 | val_loader = DataLoader( 136 | val_dataset, 137 | batch_size=config.batch_size, 138 | shuffle=False, 139 | num_workers=config.num_workers, 140 | pin_memory=True, 141 | drop_last=False, 142 | sampler=val_sampler 143 | ) 144 | 145 | if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0: 146 | print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}") 147 | 148 | return train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler 149 | 150 | 151 | def train_tokenizer(model, device, config, save_dir, logger): 152 | logger.info("Starting tokenizer training...") 153 | use_ddp = dist.is_available() and dist.is_initialized() 154 | rank = dist.get_rank() if use_ddp else 0 155 | world_size = dist.get_world_size() if use_ddp else 1 156 | 157 | train_loader, val_loader, train_dataset, val_dataset, train_sampler, val_sampler = create_dataloaders(config) 158 | 159 | optimizer = torch.optim.AdamW( 160 | model.parameters(), 161 | lr=config.tokenizer_learning_rate, 162 | weight_decay=config.adam_weight_decay 163 | ) 164 | 165 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 166 | optimizer, 167 | max_lr=config.tokenizer_learning_rate, 168 | steps_per_epoch=len(train_loader), 169 | epochs=config.tokenizer_epochs, 170 | pct_start=0.03, 171 | div_factor=10 172 | ) 173 | 174 | if use_ddp: 175 | local_rank = int(os.environ.get("LOCAL_RANK", "0")) 176 | model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False) 177 | 
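    # The loop below follows the same gradient-accumulation pattern as
    # finetune/train_tokenizer.py: each DataLoader batch is split into
    # `accumulation_steps` micro-batches, backward passes are scaled by
    # 1 / accumulation_steps, and one optimizer/scheduler step is taken per batch.
    # For example, with the default batch_size=160 and accumulation_steps=4
    # (illustrative value), each micro-batch holds 40 samples.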
178 | best_val_loss = float("inf") 179 | batch_idx_global = 0 180 | 181 | accumulation_steps = getattr(config, 'accumulation_steps', 1) 182 | 183 | for epoch in range(config.tokenizer_epochs): 184 | epoch_start_time = time.time() 185 | model.train() 186 | 187 | train_dataset.set_epoch_seed(epoch * 10000) 188 | val_dataset.set_epoch_seed(0) 189 | if train_sampler is not None: 190 | train_sampler.set_epoch(epoch) 191 | 192 | for batch_idx, (ori_batch_x, _) in enumerate(train_loader): 193 | ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True) 194 | 195 | current_batch_total_loss = 0.0 196 | for j in range(accumulation_steps): 197 | start_idx = j * (ori_batch_x.shape[0] // accumulation_steps) 198 | end_idx = (j + 1) * (ori_batch_x.shape[0] // accumulation_steps) 199 | batch_x = ori_batch_x[start_idx:end_idx] 200 | 201 | zs, bsq_loss, _, _ = (model.module if use_ddp else model)(batch_x) 202 | z_pre, z = zs 203 | 204 | recon_loss_pre = F.mse_loss(z_pre, batch_x) 205 | recon_loss_all = F.mse_loss(z, batch_x) 206 | recon_loss = recon_loss_pre + recon_loss_all 207 | loss = (recon_loss + bsq_loss) / 2 208 | 209 | loss_scaled = loss / accumulation_steps 210 | current_batch_total_loss += loss.item() 211 | loss_scaled.backward() 212 | 213 | torch.nn.utils.clip_grad_norm_((model.module if use_ddp else model).parameters(), max_norm=2.0) 214 | optimizer.step() 215 | scheduler.step() 216 | optimizer.zero_grad() 217 | 218 | if (batch_idx_global + 1) % config.log_interval == 0: 219 | avg_loss = current_batch_total_loss / accumulation_steps 220 | lr = optimizer.param_groups[0]["lr"] 221 | log_msg = (f"[Epoch {epoch+1}/{config.tokenizer_epochs}, Step {batch_idx+1}/{len(train_loader)}] " 222 | f"LR: {lr:.6f}, Loss: {avg_loss:.4f}") 223 | logger.info(log_msg) 224 | if rank == 0: 225 | print(log_msg) 226 | 227 | detail_msg = (f" - VQ Loss: {bsq_loss.item():.4f}\n" 228 | f" - Recon Loss Pre: {recon_loss_pre.item():.4f}\n" 229 | f" - Recon Loss All: {recon_loss_all.item():.4f}") 230 | logger.info(detail_msg) 231 | if rank == 0: 232 | print(detail_msg) 233 | 234 | batch_idx_global += 1 235 | 236 | model.eval() 237 | tot_val_loss_sum_rank = 0.0 238 | val_sample_count_rank = 0 239 | 240 | with torch.no_grad(): 241 | for ori_batch_x, _ in val_loader: 242 | ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True) 243 | zs, _, _, _ = (model.module if use_ddp else model)(ori_batch_x) 244 | _, z = zs 245 | val_loss_item = F.mse_loss(z, ori_batch_x) 246 | 247 | tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0) 248 | val_sample_count_rank += ori_batch_x.size(0) 249 | 250 | if use_ddp: 251 | tensor_sum = torch.tensor([tot_val_loss_sum_rank, val_sample_count_rank], dtype=torch.float64, device=device) 252 | dist.all_reduce(tensor_sum, op=dist.ReduceOp.SUM) 253 | tot_val_loss_all = tensor_sum[0].item() 254 | val_count_all = int(tensor_sum[1].item()) 255 | avg_val_loss = (tot_val_loss_all / val_count_all) if val_count_all > 0 else 0.0 256 | else: 257 | avg_val_loss = tot_val_loss_sum_rank / val_sample_count_rank if val_sample_count_rank > 0 else 0 258 | 259 | epoch_time = time.time() - epoch_start_time 260 | epoch_summary = (f"\n--- Epoch {epoch+1}/{config.tokenizer_epochs} Summary ---\n" 261 | f"Validation Loss: {avg_val_loss:.4f}\n" 262 | f"Epoch Time: {format_time(epoch_time)}\n" 263 | f"Total Training Time: {format_time(time.time() - epoch_start_time)}\n") 264 | logger.info(epoch_summary) 265 | if rank == 0: 266 | print(epoch_summary) 267 | 268 | if avg_val_loss < best_val_loss: 269 | 
best_val_loss = avg_val_loss 270 | if rank == 0: 271 | model_save_path = os.path.join(save_dir, "best_model") 272 | os.makedirs(model_save_path, exist_ok=True) 273 | (model.module if use_ddp else model).save_pretrained(model_save_path) 274 | save_msg = f"Best model saved to: {model_save_path} (validation loss: {best_val_loss:.4f})" 275 | logger.info(save_msg) 276 | print(save_msg) 277 | 278 | return best_val_loss 279 | 280 | 281 | def main(): 282 | import argparse 283 | 284 | parser = argparse.ArgumentParser(description='Kronos Tokenizer Fine-tuning Training') 285 | parser.add_argument('--config', type=str, default='config.yaml', 286 | help='Configuration file path (default: config.yaml)') 287 | args = parser.parse_args() 288 | 289 | config = CustomFinetuneConfig(args.config) 290 | 291 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 292 | print(f"Using device: {device}") 293 | 294 | 295 | 296 | os.makedirs(config.tokenizer_save_path, exist_ok=True) 297 | 298 | log_dir = os.path.join(config.base_save_path, "logs") 299 | logger = setup_logging(config.exp_name, log_dir, 0) 300 | 301 | set_seed(config.seed) 302 | 303 | # Load the pretrained tokenizer 304 | if getattr(config, 'pre_trained_tokenizer', True): 305 | logger.info("Loading pretrained tokenizer...") 306 | print("Loading pretrained tokenizer...") 307 | tokenizer = KronosTokenizer.from_pretrained(config.pretrained_tokenizer_path) 308 | else: 309 | print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture") 310 | import json, os 311 | cfg_path = os.path.join(config.pretrained_tokenizer_path, 'config.json') 312 | with open(cfg_path, 'r') as f: 313 | arch = json.load(f) 314 | tokenizer = KronosTokenizer( 315 | d_in=arch.get('d_in', 6), 316 | d_model=arch.get('d_model', 256), 317 | n_heads=arch.get('n_heads', 4), 318 | ff_dim=arch.get('ff_dim', 512), 319 | n_enc_layers=arch.get('n_enc_layers', 4), 320 | n_dec_layers=arch.get('n_dec_layers', 4), 321 | ffn_dropout_p=arch.get('ffn_dropout_p', 0.0), 322 | attn_dropout_p=arch.get('attn_dropout_p', 0.0), 323 | resid_dropout_p=arch.get('resid_dropout_p', 0.0), 324 | s1_bits=arch.get('s1_bits', 10), 325 | s2_bits=arch.get('s2_bits', 10), 326 | beta=arch.get('beta', 0.05), 327 | gamma0=arch.get('gamma0', 1.0), 328 | gamma=arch.get('gamma', 1.1), 329 | zeta=arch.get('zeta', 0.05), 330 | group_size=arch.get('group_size', 4) 331 | ) 332 | tokenizer = tokenizer.to(device) 333 | 334 | model_size = get_model_size(tokenizer) 335 | logger.info(f"Tokenizer parameters: {model_size}") 336 | print(f"Tokenizer parameters: {model_size}") 337 | 338 | logger.info("=== Training Configuration ===") 339 | logger.info(f"Data path: {config.data_path}") 340 | logger.info(f"Lookback window: {config.lookback_window}") 341 | logger.info(f"Predict window: {config.predict_window}") 342 | logger.info(f"Batch size: {config.batch_size}") 343 | logger.info(f"Learning rate: {config.tokenizer_learning_rate}") 344 | logger.info(f"Training epochs: {config.tokenizer_epochs}") 345 | logger.info(f"Device: {device}") 346 | logger.info(f"Distributed training: False") 347 | 348 | logger.info("Starting tokenizer fine-tuning training...") 349 | print("Starting tokenizer fine-tuning training...") 350 | best_val_loss = train_tokenizer(tokenizer, device, config, config.tokenizer_save_path, logger) 351 | 352 | final_msg = f"Tokenizer training completed!
Best validation loss: {best_val_loss:.4f}\nModel saved to: {config.tokenizer_save_path}" 353 | logger.info(final_msg) 354 | print(final_msg) 355 | 356 | 357 | if __name__ == "__main__": 358 | main() 359 | 360 | -------------------------------------------------------------------------------- /finetune/qlib_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import pickle 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import torch 10 | from torch.utils.data import Dataset, DataLoader 11 | from tqdm import trange, tqdm 12 | from matplotlib import pyplot as plt 13 | 14 | import qlib 15 | from qlib.config import REG_CN 16 | from qlib.backtest import backtest, executor, CommonInfrastructure 17 | from qlib.contrib.evaluate import risk_analysis 18 | from qlib.contrib.strategy import TopkDropoutStrategy 19 | from qlib.utils import flatten_dict 20 | from qlib.utils.time import Freq 21 | 22 | # Ensure project root is in the Python path 23 | sys.path.append("../") 24 | from config import Config 25 | from model.kronos import Kronos, KronosTokenizer, auto_regressive_inference 26 | 27 | 28 | # ================================================================================= 29 | # 1. Data Loading and Processing for Inference 30 | # ================================================================================= 31 | 32 | class QlibTestDataset(Dataset): 33 | """ 34 | PyTorch Dataset for handling Qlib test data, specifically for inference. 35 | 36 | This dataset iterates through all possible sliding windows sequentially. It also 37 | yields metadata like symbol and timestamp, which are crucial for mapping 38 | predictions back to the original time series. 
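    Each index entry is a (symbol, start_idx, timestamp) tuple, where `timestamp`
    is the datetime of the last bar in the lookback window, and each sample yields
    the normalized context features, the context time features, and the time
    features of the prediction window.

    Example (illustrative, assuming a pickled test split such as `test_data.pkl`):
        >>> with open('test_data.pkl', 'rb') as f:
        ...     test_data = pickle.load(f)
        >>> dataset = QlibTestDataset(data=test_data, config=Config())
        >>> x, x_stamp, y_stamp, symbol, timestamp = dataset[0]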
39 | """ 40 | 41 | def __init__(self, data: dict, config: Config): 42 | self.data = data 43 | self.config = config 44 | self.window_size = config.lookback_window + config.predict_window 45 | self.symbols = list(self.data.keys()) 46 | self.feature_list = config.feature_list 47 | self.time_feature_list = config.time_feature_list 48 | self.indices = [] 49 | 50 | print("Preprocessing and building indices for test dataset...") 51 | for symbol in self.symbols: 52 | df = self.data[symbol].reset_index() 53 | # Generate time features on-the-fly 54 | df['minute'] = df['datetime'].dt.minute 55 | df['hour'] = df['datetime'].dt.hour 56 | df['weekday'] = df['datetime'].dt.weekday 57 | df['day'] = df['datetime'].dt.day 58 | df['month'] = df['datetime'].dt.month 59 | self.data[symbol] = df # Store preprocessed dataframe 60 | 61 | num_samples = len(df) - self.window_size + 1 62 | if num_samples > 0: 63 | for i in range(num_samples): 64 | timestamp = df.iloc[i + self.config.lookback_window - 1]['datetime'] 65 | self.indices.append((symbol, i, timestamp)) 66 | 67 | def __len__(self) -> int: 68 | return len(self.indices) 69 | 70 | def __getitem__(self, idx: int): 71 | symbol, start_idx, timestamp = self.indices[idx] 72 | df = self.data[symbol] 73 | 74 | context_end = start_idx + self.config.lookback_window 75 | predict_end = context_end + self.config.predict_window 76 | 77 | context_df = df.iloc[start_idx:context_end] 78 | predict_df = df.iloc[context_end:predict_end] 79 | 80 | x = context_df[self.feature_list].values.astype(np.float32) 81 | x_stamp = context_df[self.time_feature_list].values.astype(np.float32) 82 | y_stamp = predict_df[self.time_feature_list].values.astype(np.float32) 83 | 84 | # Instance-level normalization, consistent with training 85 | x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0) 86 | x = (x - x_mean) / (x_std + 1e-5) 87 | x = np.clip(x, -self.config.clip, self.config.clip) 88 | 89 | return torch.from_numpy(x), torch.from_numpy(x_stamp), torch.from_numpy(y_stamp), symbol, timestamp 90 | 91 | 92 | # ================================================================================= 93 | # 2. Backtesting Logic 94 | # ================================================================================= 95 | 96 | class QlibBacktest: 97 | """ 98 | A wrapper class for conducting backtesting experiments using Qlib. 99 | """ 100 | 101 | def __init__(self, config: Config): 102 | self.config = config 103 | self.initialize_qlib() 104 | 105 | def initialize_qlib(self): 106 | """Initializes the Qlib environment.""" 107 | print("Initializing Qlib for backtesting...") 108 | qlib.init(provider_uri=self.config.qlib_data_path, region=REG_CN) 109 | 110 | def run_single_backtest(self, signal_series: pd.Series) -> pd.DataFrame: 111 | """ 112 | Runs a single backtest for a given prediction signal. 113 | 114 | Args: 115 | signal_series (pd.Series): A pandas Series with a MultiIndex 116 | (instrument, datetime) and prediction scores. 117 | Returns: 118 | pd.DataFrame: A DataFrame containing the performance report. 
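        The returned DataFrame holds three cumulative series: 'cum_bench' (benchmark
        return), 'cum_return_w_cost' (strategy return net of cost), and
        'cum_ex_return_w_cost' (excess return over the benchmark, net of cost).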
119 | """ 120 | strategy = TopkDropoutStrategy( 121 | topk=self.config.backtest_n_symbol_hold, 122 | n_drop=self.config.backtest_n_symbol_drop, 123 | hold_thresh=self.config.backtest_hold_thresh, 124 | signal=signal_series, 125 | ) 126 | executor_config = { 127 | "time_per_step": "day", 128 | "generate_portfolio_metrics": True, 129 | "delay_execution": True, 130 | } 131 | backtest_config = { 132 | "start_time": self.config.backtest_time_range[0], 133 | "end_time": self.config.backtest_time_range[1], 134 | "account": 100_000_000, 135 | "benchmark": self.config.backtest_benchmark, 136 | "exchange_kwargs": { 137 | "freq": "day", "limit_threshold": 0.095, "deal_price": "open", 138 | "open_cost": 0.001, "close_cost": 0.0015, "min_cost": 5, 139 | }, 140 | "executor": executor.SimulatorExecutor(**executor_config), 141 | } 142 | 143 | portfolio_metric_dict, _ = backtest(strategy=strategy, **backtest_config) 144 | analysis_freq = "{0}{1}".format(*Freq.parse("day")) 145 | report, _ = portfolio_metric_dict.get(analysis_freq) 146 | 147 | # --- Analysis and Reporting --- 148 | analysis = { 149 | "excess_return_without_cost": risk_analysis(report["return"] - report["bench"], freq=analysis_freq), 150 | "excess_return_with_cost": risk_analysis(report["return"] - report["bench"] - report["cost"], freq=analysis_freq), 151 | } 152 | print("\n--- Backtest Analysis ---") 153 | print("Benchmark Return:", risk_analysis(report["bench"], freq=analysis_freq), sep='\n') 154 | print("\nExcess Return (w/o cost):", analysis["excess_return_without_cost"], sep='\n') 155 | print("\nExcess Return (w/ cost):", analysis["excess_return_with_cost"], sep='\n') 156 | 157 | report_df = pd.DataFrame({ 158 | "cum_bench": report["bench"].cumsum(), 159 | "cum_return_w_cost": (report["return"] - report["cost"]).cumsum(), 160 | "cum_ex_return_w_cost": (report["return"] - report["bench"] - report["cost"]).cumsum(), 161 | }) 162 | return report_df 163 | 164 | def run_and_plot_results(self, signals: dict[str, pd.DataFrame]): 165 | """ 166 | Runs backtests for multiple signals and plots the cumulative return curves. 167 | 168 | Args: 169 | signals (dict[str, pd.DataFrame]): A dictionary where keys are signal names 170 | and values are prediction DataFrames. 
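        Each prediction DataFrame is expected to have a datetime index and one
        column of scores per instrument, as produced by `generate_predictions`.

        Example (illustrative):
            >>> backtester = QlibBacktest(Config())
            >>> backtester.run_and_plot_results({'mean': model_preds['mean']})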
171 | """ 172 | return_df, ex_return_df, bench_df = pd.DataFrame(), pd.DataFrame(), pd.DataFrame() 173 | 174 | for signal_name, pred_df in signals.items(): 175 | print(f"\nBacktesting signal: {signal_name}...") 176 | pred_series = pred_df.stack() 177 | pred_series.index.names = ['datetime', 'instrument'] 178 | pred_series = pred_series.swaplevel().sort_index() 179 | report_df = self.run_single_backtest(pred_series) 180 | 181 | return_df[signal_name] = report_df['cum_return_w_cost'] 182 | ex_return_df[signal_name] = report_df['cum_ex_return_w_cost'] 183 | if 'return' not in bench_df: 184 | bench_df['return'] = report_df['cum_bench'] 185 | 186 | # Plotting results 187 | fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True) 188 | return_df.plot(ax=axes[0], title='Cumulative Return with Cost', grid=True) 189 | axes[0].plot(bench_df['return'], label=self.config.instrument.upper(), color='black', linestyle='--') 190 | axes[0].legend() 191 | axes[0].set_ylabel("Cumulative Return") 192 | 193 | ex_return_df.plot(ax=axes[1], title='Cumulative Excess Return with Cost', grid=True) 194 | axes[1].legend() 195 | axes[1].set_xlabel("Date") 196 | axes[1].set_ylabel("Cumulative Excess Return") 197 | 198 | plt.tight_layout() 199 | plt.savefig("../figures/backtest_result_example.png", dpi=200) 200 | plt.show() 201 | 202 | 203 | # ================================================================================= 204 | # 3. Inference Logic 205 | # ================================================================================= 206 | 207 | def load_models(config: dict) -> tuple[KronosTokenizer, Kronos]: 208 | """Loads the fine-tuned tokenizer and predictor model.""" 209 | device = torch.device(config['device']) 210 | print(f"Loading models onto device: {device}...") 211 | tokenizer = KronosTokenizer.from_pretrained(config['tokenizer_path']).to(device).eval() 212 | model = Kronos.from_pretrained(config['model_path']).to(device).eval() 213 | return tokenizer, model 214 | 215 | 216 | def collate_fn_for_inference(batch): 217 | """ 218 | Custom collate function to handle batches containing Tensors, strings, and Timestamps. 219 | 220 | Args: 221 | batch (list): A list of samples, where each sample is the tuple returned by 222 | QlibTestDataset.__getitem__. 223 | 224 | Returns: 225 | A single tuple containing the batched data. 226 | """ 227 | # Unzip the list of samples into separate lists for each data type 228 | x, x_stamp, y_stamp, symbols, timestamps = zip(*batch) 229 | 230 | # Stack the tensors to create a batch 231 | x_batch = torch.stack(x, dim=0) 232 | x_stamp_batch = torch.stack(x_stamp, dim=0) 233 | y_stamp_batch = torch.stack(y_stamp, dim=0) 234 | 235 | # Return the strings and timestamps as lists 236 | return x_batch, x_stamp_batch, y_stamp_batch, list(symbols), list(timestamps) 237 | 238 | 239 | def generate_predictions(config: dict, test_data: dict) -> dict[str, pd.DataFrame]: 240 | """ 241 | Runs inference on the test dataset to generate prediction signals. 242 | 243 | Args: 244 | config (dict): A dictionary containing inference parameters. 245 | test_data (dict): The raw test data loaded from a pickle file. 246 | 247 | Returns: 248 | A dictionary where keys are signal types (e.g., 'mean', 'last') and 249 | values are DataFrames of predictions (datetime index, symbol columns). 
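        Four signals are produced per (timestamp, symbol) pair: 'last', 'mean',
        'max' and 'min', each defined as the corresponding statistic of the
        predicted close price over the prediction window minus the close of the
        final context bar.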
250 | """ 251 | tokenizer, model = load_models(config) 252 | device = torch.device(config['device']) 253 | 254 | # Use the Dataset and DataLoader for efficient batching and processing 255 | dataset = QlibTestDataset(data=test_data, config=Config()) 256 | loader = DataLoader( 257 | dataset, 258 | batch_size=config['batch_size'] // config['sample_count'], 259 | shuffle=False, 260 | num_workers=os.cpu_count() // 2, 261 | collate_fn=collate_fn_for_inference 262 | ) 263 | 264 | results = defaultdict(list) 265 | with torch.no_grad(): 266 | for x, x_stamp, y_stamp, symbols, timestamps in tqdm(loader, desc="Inference"): 267 | preds = auto_regressive_inference( 268 | tokenizer, model, x.to(device), x_stamp.to(device), y_stamp.to(device), 269 | max_context=config['max_context'], pred_len=config['pred_len'], clip=config['clip'], 270 | T=config['T'], top_k=config['top_k'], top_p=config['top_p'], sample_count=config['sample_count'] 271 | ) 272 | # You can try commenting on this line to keep the history data 273 | preds = preds[:, -config['pred_len']:, :] 274 | 275 | # The 'close' price is at index 3 in `feature_list` 276 | last_day_close = x[:, -1, 3].numpy() 277 | signals = { 278 | 'last': preds[:, -1, 3] - last_day_close, 279 | 'mean': np.mean(preds[:, :, 3], axis=1) - last_day_close, 280 | 'max': np.max(preds[:, :, 3], axis=1) - last_day_close, 281 | 'min': np.min(preds[:, :, 3], axis=1) - last_day_close, 282 | } 283 | 284 | for i in range(len(symbols)): 285 | for sig_type, sig_values in signals.items(): 286 | results[sig_type].append((timestamps[i], symbols[i], sig_values[i])) 287 | 288 | print("Post-processing predictions into DataFrames...") 289 | prediction_dfs = {} 290 | for sig_type, records in results.items(): 291 | df = pd.DataFrame(records, columns=['datetime', 'instrument', 'score']) 292 | pivot_df = df.pivot_table(index='datetime', columns='instrument', values='score') 293 | prediction_dfs[sig_type] = pivot_df.sort_index() 294 | 295 | return prediction_dfs 296 | 297 | 298 | # ================================================================================= 299 | # 4. Main Execution 300 | # ================================================================================= 301 | 302 | def main(): 303 | """Main function to set up config, run inference, and execute backtesting.""" 304 | parser = argparse.ArgumentParser(description="Run Kronos Inference and Backtesting") 305 | parser.add_argument("--device", type=str, default="cuda:1", help="Device for inference (e.g., 'cuda:0', 'cpu')") 306 | args = parser.parse_args() 307 | 308 | # --- 1. 
Configuration Setup --- 309 | base_config = Config() 310 | 311 | # Create a dedicated dictionary for this run's configuration 312 | run_config = { 313 | 'device': args.device, 314 | 'data_path': base_config.dataset_path, 315 | 'result_save_path': base_config.backtest_result_path, 316 | 'result_name': base_config.backtest_save_folder_name, 317 | 'tokenizer_path': base_config.finetuned_tokenizer_path, 318 | 'model_path': base_config.finetuned_predictor_path, 319 | 'max_context': base_config.max_context, 320 | 'pred_len': base_config.predict_window, 321 | 'clip': base_config.clip, 322 | 'T': base_config.inference_T, 323 | 'top_k': base_config.inference_top_k, 324 | 'top_p': base_config.inference_top_p, 325 | 'sample_count': base_config.inference_sample_count, 326 | 'batch_size': base_config.backtest_batch_size, 327 | } 328 | 329 | print("--- Running with Configuration ---") 330 | for key, val in run_config.items(): 331 | print(f"{key:>20}: {val}") 332 | print("-" * 35) 333 | 334 | # --- 2. Load Data --- 335 | test_data_path = os.path.join(run_config['data_path'], "test_data.pkl") 336 | print(f"Loading test data from {test_data_path}...") 337 | with open(test_data_path, 'rb') as f: 338 | test_data = pickle.load(f) 339 | print(test_data) 340 | # --- 3. Generate Predictions --- 341 | model_preds = generate_predictions(run_config, test_data) 342 | 343 | # --- 4. Save Predictions --- 344 | save_dir = os.path.join(run_config['result_save_path'], run_config['result_name']) 345 | os.makedirs(save_dir, exist_ok=True) 346 | predictions_file = os.path.join(save_dir, "predictions.pkl") 347 | print(f"Saving prediction signals to {predictions_file}...") 348 | with open(predictions_file, 'wb') as f: 349 | pickle.dump(model_preds, f) 350 | 351 | # --- 5. Run Backtesting --- 352 | with open(predictions_file, 'rb') as f: 353 | model_preds = pickle.load(f) 354 | 355 | backtester = QlibBacktest(base_config) 356 | backtester.run_and_plot_results(model_preds) 357 | 358 | 359 | if __name__ == '__main__': 360 | main() 361 | 362 | 363 | -------------------------------------------------------------------------------- /finetune_csv/train_sequential.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import argparse 5 | import torch 6 | import torch.nn as nn 7 | from torch.utils.data import DataLoader 8 | import torch.distributed as dist 9 | 10 | sys.path.append('../') 11 | from model import Kronos, KronosTokenizer, KronosPredictor 12 | 13 | from config_loader import CustomFinetuneConfig 14 | from finetune_tokenizer import train_tokenizer, set_seed, setup_logging as setup_tokenizer_logging 15 | from finetune_base_model import train_model, create_dataloaders, setup_logging as setup_basemodel_logging 16 | 17 | 18 | class SequentialTrainer: 19 | 20 | def __init__(self, config_path: str = None): 21 | self.config = CustomFinetuneConfig(config_path) 22 | self.rank = int(os.environ.get("RANK", "0")) 23 | self.world_size = int(os.environ.get("WORLD_SIZE", "1")) 24 | self.local_rank = int(os.environ.get("LOCAL_RANK", str(self.config.device_id if hasattr(self.config, 'device_id') else 0))) 25 | self.device = self._setup_device() 26 | 27 | self.config.print_config_summary() 28 | 29 | def _setup_device(self): 30 | if self.config.use_cuda and torch.cuda.is_available(): 31 | torch.cuda.set_device(self.local_rank) 32 | device = torch.device(f"cuda:{self.local_rank}") 33 | else: 34 | device = torch.device("cpu") 35 | 36 | if self.rank == 0: 37 | 
print(f"Using device: {device} (rank={self.rank}, world_size={self.world_size}, local_rank={self.local_rank})") 38 | return device 39 | 40 | def _setup_distributed(self): 41 | if self.world_size > 1 and torch.cuda.is_available(): 42 | backend = os.environ.get("DIST_BACKEND", "nccl").lower() 43 | if not dist.is_initialized(): 44 | dist.init_process_group(backend=backend) 45 | if self.rank == 0: 46 | print(f"Distributed training initialized: backend={backend}, world_size={self.world_size}") 47 | else: 48 | if self.rank == 0: 49 | print("Distributed training not enabled, using single GPU/CPU training") 50 | 51 | def _check_existing_models(self): 52 | tokenizer_exists = os.path.exists(self.config.tokenizer_best_model_path) 53 | basemodel_exists = os.path.exists(self.config.basemodel_best_model_path) 54 | 55 | print(f"Tokenizer model exists: {tokenizer_exists}") 56 | print(f"Basemodel model exists: {basemodel_exists}") 57 | 58 | return tokenizer_exists, basemodel_exists 59 | 60 | def _create_directories(self): 61 | os.makedirs(self.config.tokenizer_save_path, exist_ok=True) 62 | os.makedirs(self.config.basemodel_save_path, exist_ok=True) 63 | print(f"Created directory: {self.config.tokenizer_save_path}") 64 | print(f"Created directory: {self.config.basemodel_save_path}") 65 | 66 | def train_tokenizer_phase(self): 67 | print("\n" + "="*60) 68 | print("Starting Tokenizer Fine-tuning Phase") 69 | print("="*60) 70 | 71 | tokenizer_exists, _ = self._check_existing_models() 72 | if tokenizer_exists and self.config.skip_existing: 73 | print("Tokenizer model already exists, skipping training") 74 | return True 75 | 76 | log_dir = os.path.join(self.config.base_save_path, "logs") 77 | logger = setup_tokenizer_logging(self.config.exp_name, log_dir, self.rank) 78 | 79 | set_seed(self.config.seed) 80 | 81 | if getattr(self.config, 'pre_trained_tokenizer', True): 82 | logger.info("Loading pretrained tokenizer...") 83 | if self.rank == 0: 84 | print("Loading pretrained tokenizer...") 85 | tokenizer = KronosTokenizer.from_pretrained(self.config.pretrained_tokenizer_path) 86 | else: 87 | if self.rank == 0: 88 | print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture") 89 | import json 90 | cfg_path = os.path.join(self.config.pretrained_tokenizer_path, 'config.json') 91 | with open(cfg_path, 'r') as f: 92 | arch = json.load(f) 93 | tokenizer = KronosTokenizer( 94 | d_in=arch.get('d_in', 6), 95 | d_model=arch.get('d_model', 256), 96 | n_heads=arch.get('n_heads', 4), 97 | ff_dim=arch.get('ff_dim', 512), 98 | n_enc_layers=arch.get('n_enc_layers', 4), 99 | n_dec_layers=arch.get('n_dec_layers', 4), 100 | ffn_dropout_p=arch.get('ffn_dropout_p', 0.0), 101 | attn_dropout_p=arch.get('attn_dropout_p', 0.0), 102 | resid_dropout_p=arch.get('resid_dropout_p', 0.0), 103 | s1_bits=arch.get('s1_bits', 10), 104 | s2_bits=arch.get('s2_bits', 10), 105 | beta=arch.get('beta', 0.05), 106 | gamma0=arch.get('gamma0', 1.0), 107 | gamma=arch.get('gamma', 1.1), 108 | zeta=arch.get('zeta', 0.05), 109 | group_size=arch.get('group_size', 4) 110 | ) 111 | tokenizer = tokenizer.to(self.device) 112 | 113 | model_size = sum(p.numel() for p in tokenizer.parameters()) 114 | logger.info(f"Tokenizer parameters: {model_size:,}") 115 | if self.rank == 0: 116 | print(f"Tokenizer parameters: {model_size:,}") 117 | 118 | logger.info("=== Training Configuration ===") 119 | logger.info(f"Data path: {self.config.data_path}") 120 | logger.info(f"Lookback window: {self.config.lookback_window}") 121 | logger.info(f"Predict window: 
{self.config.predict_window}") 122 | logger.info(f"Batch size: {self.config.batch_size}") 123 | logger.info(f"Learning rate: {self.config.tokenizer_learning_rate}") 124 | logger.info(f"Training epochs: {self.config.tokenizer_epochs}") 125 | logger.info(f"Device: {self.device}") 126 | logger.info(f"Distributed training: False") 127 | 128 | logger.info("Starting tokenizer fine-tuning training...") 129 | if self.rank == 0: 130 | print("Starting tokenizer fine-tuning training...") 131 | start_time = time.time() 132 | best_val_loss = train_tokenizer( 133 | tokenizer, 134 | self.device, 135 | self.config, 136 | self.config.tokenizer_save_path, 137 | logger, 138 | ) 139 | training_time = time.time() - start_time 140 | 141 | final_msg = f"Tokenizer training completed! Best validation loss: {best_val_loss:.4f}\nTraining time: {training_time/60:.2f} minutes\nModel saved to: {self.config.tokenizer_save_path}" 142 | logger.info(final_msg) 143 | if self.rank == 0: 144 | print(f"\n{final_msg}") 145 | 146 | return True 147 | 148 | def train_basemodel_phase(self): 149 | print("\n" + "="*60) 150 | print("Starting Basemodel Fine-tuning Phase") 151 | print("="*60) 152 | 153 | if getattr(self.config, 'pre_trained_tokenizer', True): 154 | if not os.path.exists(self.config.finetuned_tokenizer_path): 155 | raise FileNotFoundError(f"Fine-tuned tokenizer does not exist: {self.config.finetuned_tokenizer_path}") 156 | 157 | _, basemodel_exists = self._check_existing_models() 158 | if basemodel_exists and self.config.skip_existing: 159 | print("Basemodel model already exists, skipping training") 160 | return True 161 | 162 | log_dir = os.path.join(self.config.base_save_path, "logs") 163 | logger = setup_basemodel_logging(self.config.exp_name, log_dir, self.rank) 164 | 165 | set_seed(self.config.seed) 166 | 167 | if getattr(self.config, 'pre_trained_tokenizer', True): 168 | logger.info("Loading fine-tuned tokenizer...") 169 | if self.rank == 0: 170 | print("Loading fine-tuned tokenizer...") 171 | tokenizer = KronosTokenizer.from_pretrained(self.config.finetuned_tokenizer_path) 172 | else: 173 | if self.rank == 0: 174 | print("pre_trained_tokenizer=False, randomly initializing Tokenizer architecture for Predictor training") 175 | import json 176 | cfg_path = os.path.join(self.config.pretrained_tokenizer_path, 'config.json') 177 | with open(cfg_path, 'r') as f: 178 | arch = json.load(f) 179 | tokenizer = KronosTokenizer( 180 | d_in=arch.get('d_in', 6), 181 | d_model=arch.get('d_model', 256), 182 | n_heads=arch.get('n_heads', 4), 183 | ff_dim=arch.get('ff_dim', 512), 184 | n_enc_layers=arch.get('n_enc_layers', 4), 185 | n_dec_layers=arch.get('n_dec_layers', 4), 186 | ffn_dropout_p=arch.get('ffn_dropout_p', 0.0), 187 | attn_dropout_p=arch.get('attn_dropout_p', 0.0), 188 | resid_dropout_p=arch.get('resid_dropout_p', 0.0), 189 | s1_bits=arch.get('s1_bits', 10), 190 | s2_bits=arch.get('s2_bits', 10), 191 | beta=arch.get('beta', 0.05), 192 | gamma0=arch.get('gamma0', 1.0), 193 | gamma=arch.get('gamma', 1.1), 194 | zeta=arch.get('zeta', 0.05), 195 | group_size=arch.get('group_size', 4) 196 | ) 197 | tokenizer = tokenizer.to(self.device) 198 | 199 | if getattr(self.config, 'pre_trained_predictor', True): 200 | logger.info("Loading pretrained predictor...") 201 | if self.rank == 0: 202 | print("Loading pretrained predictor...") 203 | model = Kronos.from_pretrained(self.config.pretrained_predictor_path) 204 | else: 205 | if self.rank == 0: 206 | print("pre_trained_predictor=False, randomly initializing Predictor architecture") 
207 | import json 208 | cfg_path = os.path.join(self.config.pretrained_predictor_path, 'config.json') 209 | with open(cfg_path, 'r') as f: 210 | arch = json.load(f) 211 | print("model_config: ", arch) 212 | model = Kronos( 213 | s1_bits=arch.get('s1_bits', 10), 214 | s2_bits=arch.get('s2_bits', 10), 215 | n_layers=arch.get('n_layers', 12), 216 | d_model=arch.get('d_model', 832), 217 | n_heads=arch.get('n_heads', 16), 218 | ff_dim=arch.get('ff_dim', 2048), 219 | ffn_dropout_p=arch.get('ffn_dropout_p', 0.2), 220 | attn_dropout_p=arch.get('attn_dropout_p', 0.0), 221 | resid_dropout_p=arch.get('resid_dropout_p', 0.2), 222 | token_dropout_p=arch.get('token_dropout_p', 0.0), 223 | learn_te=arch.get('learn_te', True) 224 | ) 225 | model = model.to(self.device) 226 | 227 | model_size = sum(p.numel() for p in model.parameters()) 228 | logger.info(f"Model parameters: {model_size:,}") 229 | if self.rank == 0: 230 | print(f"Model parameters: {model_size:,}") 231 | 232 | logger.info("=== Training Configuration ===") 233 | logger.info(f"Data path: {self.config.data_path}") 234 | logger.info(f"Lookback window: {self.config.lookback_window}") 235 | logger.info(f"Predict window: {self.config.predict_window}") 236 | logger.info(f"Batch size: {self.config.batch_size}") 237 | logger.info(f"Learning rate: {self.config.predictor_learning_rate}") 238 | logger.info(f"Training epochs: {self.config.basemodel_epochs}") 239 | logger.info(f"Device: {self.device}") 240 | logger.info(f"Tokenizer path: {self.config.finetuned_tokenizer_path}") 241 | logger.info(f"Pretrained model path: {self.config.pretrained_predictor_path}") 242 | 243 | logger.info("Starting fine-tuning training...") 244 | if self.rank == 0: 245 | print("Starting fine-tuning training...") 246 | start_time = time.time() 247 | best_val_loss = train_model( 248 | model, 249 | tokenizer, 250 | self.device, 251 | self.config, 252 | self.config.basemodel_save_path, 253 | logger, 254 | ) 255 | training_time = time.time() - start_time 256 | 257 | final_msg = f"Basemodel training completed! 
Best validation loss: {best_val_loss:.4f}\nTraining time: {training_time/60:.2f} minutes\nModel saved to: {self.config.basemodel_save_path}" 258 | logger.info(final_msg) 259 | if self.rank == 0: 260 | print(f"\n{final_msg}") 261 | 262 | return True 263 | 264 | def run_training(self): 265 | if self.rank == 0: 266 | print("Starting Kronos model sequential fine-tuning training") 267 | print(f"Experiment name: {self.config.experiment_name}") 268 | print(f"Experiment description: {self.config.experiment_description}") 269 | 270 | self._setup_distributed() 271 | 272 | self._create_directories() 273 | 274 | tokenizer_exists, basemodel_exists = self._check_existing_models() 275 | 276 | total_start_time = time.time() 277 | 278 | try: 279 | if self.config.train_tokenizer: 280 | success = self.train_tokenizer_phase() 281 | if not success: 282 | print("Tokenizer training failed, terminating training") 283 | return False 284 | else: 285 | print("Skipping Tokenizer training phase") 286 | 287 | if self.config.train_basemodel: 288 | success = self.train_basemodel_phase() 289 | if not success: 290 | print("Basemodel training failed, terminating training") 291 | return False 292 | else: 293 | print("Skipping Basemodel training phase") 294 | 295 | total_time = time.time() - total_start_time 296 | 297 | if self.rank == 0: 298 | print("\n" + "="*60) 299 | print("Training completed!") 300 | print("="*60) 301 | print(f"Total training time: {total_time/60:.2f} minutes") 302 | print(f"Tokenizer model: {self.config.tokenizer_best_model_path}") 303 | print(f"Basemodel model: {self.config.basemodel_best_model_path}") 304 | print("="*60) 305 | 306 | return True 307 | 308 | except Exception as e: 309 | if self.rank == 0: 310 | print(f"Error occurred during training: {str(e)}") 311 | import traceback 312 | traceback.print_exc() 313 | return False 314 | 315 | finally: 316 | pass 317 | 318 | 319 | def main(): 320 | parser = argparse.ArgumentParser(description='Kronos Model Sequential Fine-tuning Training') 321 | parser.add_argument('--config', type=str, default='config.yaml', 322 | help='Configuration file path (default: config.yaml)') 323 | parser.add_argument('--skip-tokenizer', action='store_true', 324 | help='Skip tokenizer training phase') 325 | parser.add_argument('--skip-basemodel', action='store_true', 326 | help='Skip basemodel training phase') 327 | parser.add_argument('--skip-existing', action='store_true', 328 | help='Skip training for existing models') 329 | 330 | args = parser.parse_args() 331 | 332 | trainer = SequentialTrainer(args.config) 333 | 334 | if args.skip_tokenizer: 335 | trainer.config.train_tokenizer = False 336 | if args.skip_basemodel: 337 | trainer.config.train_basemodel = False 338 | if args.skip_existing: 339 | trainer.config.skip_existing = True 340 | 341 | success = trainer.run_training() 342 | 343 | if success: 344 | print("Training completed successfully!") 345 | if dist.is_available() and dist.is_initialized(): 346 | dist.barrier() 347 | dist.destroy_process_group() 348 | sys.exit(0) 349 | else: 350 | print("Training failed!") 351 | if dist.is_available() and dist.is_initialized(): 352 | try: 353 | dist.barrier() 354 | dist.destroy_process_group() 355 | except Exception: 356 | pass 357 | sys.exit(1) 358 | 359 | 360 | if __name__ == "__main__": 361 | main() 362 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
57 | 58 | ## 📜 Introduction 59 | 60 | **Kronos** is a family of decoder-only foundation models, pre-trained specifically for the "language" of financial markets—K-line sequences. Unlike general-purpose time-series foundation models (TSFMs), Kronos is designed to handle the unique, high-noise characteristics of financial data. It leverages a novel two-stage framework: 61 | 1. A specialized tokenizer first quantizes continuous, multi-dimensional K-line data (OHLCV) into **hierarchical discrete tokens**. 62 | 2. A large, autoregressive Transformer is then pre-trained on these tokens, enabling it to serve as a unified model for diverse quantitative tasks. 63 | 64 |
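To make the two-stage design concrete, below is a minimal inference sketch showing how a pretrained tokenizer and predictor are loaded and combined. It is illustrative only: the checkpoint paths and the input CSV layout are placeholders, and the `KronosPredictor` constructor and `predict()` keyword arguments are assumptions that may differ from the example scripts shipped with the repository, which remain the authoritative reference. The `from_pretrained` loaders and the sampling parameters (`T`, `top_p`, `sample_count`, `max_context`, `pred_len`) mirror names used elsewhere in this codebase.

```python
import pandas as pd
from model import Kronos, KronosTokenizer, KronosPredictor

# Stage 1: the tokenizer that quantizes continuous OHLCV K-lines into
# hierarchical discrete tokens.
tokenizer = KronosTokenizer.from_pretrained("path/to/kronos-tokenizer")  # placeholder path

# Stage 2: the autoregressive Transformer pre-trained on those tokens.
model = Kronos.from_pretrained("path/to/kronos-predictor")  # placeholder path

# KronosPredictor wraps both stages for end-to-end forecasting.
# NOTE: the constructor and predict() arguments below are assumed for
# illustration, not quoted verbatim from the API.
predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512)

# Assumed CSV layout: timestamps, open, high, low, close, volume.
df = pd.read_csv("your_kline_data.csv")
df["timestamps"] = pd.to_datetime(df["timestamps"])
lookback, pred_len = 400, 120  # context length and forecast horizon, in bars

pred_df = predictor.predict(
    df=df[["open", "high", "low", "close", "volume"]].iloc[:lookback],  # historical context window
    x_timestamp=df["timestamps"].iloc[:lookback],                       # timestamps of the context
    y_timestamp=df["timestamps"].iloc[lookback:lookback + pred_len],    # timestamps to forecast
    pred_len=pred_len,
    T=1.0, top_p=0.9, sample_count=1,  # sampling temperature, nucleus cutoff, number of sampled paths
)
print(pred_df.head())
```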