├── finetune
│   ├── utils
│   │   ├── __init__.py
│   │   └── training_utils.py
│   ├── qlib_data_preprocess.py
│   ├── dataset.py
│   ├── config.py
│   ├── train_predictor.py
│   ├── train_tokenizer.py
│   └── qlib_test.py
├── facecat
│   ├── facecatcpp.dll
│   ├── image
│   │   ├── 主界面.png
│   │   ├── 五档行情.png
│   │   ├── 代码表.png
│   │   ├── 分时图.png
│   │   ├── 回测模式.png
│   │   ├── 多k线.png
│   │   ├── 训练参数.png
│   │   ├── 预测k线.png
│   │   ├── 预测模式.png
│   │   └── 预测界面.png
│   ├── model
│   │   ├── Kronos-small
│   │   │   └── config.json
│   │   ├── Kronos-Tokenizer-base
│   │   │   └── config.json
│   │   ├── __init__.py
│   │   ├── module.py
│   │   └── kronos.py
│   └── stock.py
├── requirements.txt
├── model
│   ├── __init__.py
│   ├── module.py
│   └── kronos.py
├── LICENSE
├── examples
│   ├── prediction_wo_vol_example.py
│   ├── cpu_prediction_wo_vol_examples.py
│   ├── prediction_example.py
│   └── cpu_prediction_example.py
├── README.md
└── README.en.md

/finetune/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/facecat/facecatcpp.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fidingks/facecat-kronos/HEAD/facecat/facecatcpp.dll
--------------------------------------------------------------------------------
/facecat/image/主界面.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fidingks/facecat-kronos/HEAD/facecat/image/主界面.png
--------------------------------------------------------------------------------
/facecat/image/五档行情.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fidingks/facecat-kronos/HEAD/facecat/image/五档行情.png
--------------------------------------------------------------------------------
/facecat/image/代码表.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fidingks/facecat-kronos/HEAD/facecat/image/代码表.png
--------------------------------------------------------------------------------
/facecat/image/分时图.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fidingks/facecat-kronos/HEAD/facecat/image/分时图.png
--------------------------------------------------------------------------------
/facecat/image/回测模式.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fidingks/facecat-kronos/HEAD/facecat/image/回测模式.png
--------------------------------------------------------------------------------
/facecat/image/多k线.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fidingks/facecat-kronos/HEAD/facecat/image/多k线.png
--------------------------------------------------------------------------------
/facecat/image/训练参数.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fidingks/facecat-kronos/HEAD/facecat/image/训练参数.png
--------------------------------------------------------------------------------
/facecat/image/预测k线.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fidingks/facecat-kronos/HEAD/facecat/image/预测k线.png
--------------------------------------------------------------------------------
/facecat/image/预测模式.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fidingks/facecat-kronos/HEAD/facecat/image/预测模式.png
--------------------------------------------------------------------------------
/facecat/image/预测界面.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fidingks/facecat-kronos/HEAD/facecat/image/预测界面.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | torch
3 |
4 | einops==0.8.1
5 | huggingface_hub==0.33.1
6 | matplotlib==3.9.3
7 | pandas==2.2.2
8 | tqdm==4.67.1
--------------------------------------------------------------------------------
/facecat/model/Kronos-small/config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "attn_dropout_p": 0.1,
3 |   "d_model": 512,
4 |   "ff_dim": 1024,
5 |   "ffn_dropout_p": 0.25,
6 |   "learn_te": true,
7 |   "n_heads": 8,
8 |   "n_layers": 8,
9 |   "resid_dropout_p": 0.25,
10 |   "s1_bits": 10,
11 |   "s2_bits": 10,
12 |   "token_dropout_p": 0.1
13 | }
--------------------------------------------------------------------------------
/facecat/model/Kronos-Tokenizer-base/config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "attn_dropout_p": 0.0,
3 |   "beta": 0.05,
4 |   "d_in": 6,
5 |   "d_model": 256,
6 |   "ff_dim": 512,
7 |   "ffn_dropout_p": 0.0,
8 |   "gamma": 1.1,
9 |   "gamma0": 1.0,
10 |   "group_size": 4,
11 |   "n_dec_layers": 4,
12 |   "n_enc_layers": 4,
13 |   "n_heads": 4,
14 |   "resid_dropout_p": 0.0,
15 |   "s1_bits": 10,
16 |   "s2_bits": 10,
17 |   "zeta": 0.05
18 | }
19 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .kronos import KronosTokenizer, Kronos, KronosPredictor
2 |
3 | model_dict = {
4 |     'kronos_tokenizer': KronosTokenizer,
5 |     'kronos': Kronos,
6 |     'kronos_predictor': KronosPredictor
7 | }
8 |
9 |
10 | def get_model_class(model_name):
11 |     if model_name in model_dict:
12 |         return model_dict[model_name]
13 |     else:
14 |         print(f"Model {model_name} not found in model_dict")
15 |         raise NotImplementedError
16 |
17 |
18 |
--------------------------------------------------------------------------------
/facecat/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .kronos import KronosTokenizer, Kronos, KronosPredictor
2 |
3 | model_dict = {
4 |     'kronos_tokenizer': KronosTokenizer,
5 |     'kronos': Kronos,
6 |     'kronos_predictor': KronosPredictor
7 | }
8 |
9 |
10 | def get_model_class(model_name):
11 |     if model_name in model_dict:
12 |         return model_dict[model_name]
13 |     else:
14 |         print(f"Model {model_name} not found in model_dict")
15 |         raise NotImplementedError
16 |
17 |
18 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 ShiYu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the
Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/prediction_wo_vol_example.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import sys 4 | sys.path.append("../") 5 | from model import Kronos, KronosTokenizer, KronosPredictor 6 | 7 | 8 | def plot_prediction(kline_df, pred_df): 9 | pred_df.index = kline_df.index[-pred_df.shape[0]:] 10 | sr_close = kline_df['close'] 11 | sr_pred_close = pred_df['close'] 12 | sr_close.name = 'Ground Truth' 13 | sr_pred_close.name = "Prediction" 14 | 15 | close_df = pd.concat([sr_close, sr_pred_close], axis=1) 16 | 17 | fig, ax = plt.subplots(1, 1, figsize=(8, 4)) 18 | 19 | ax.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) 20 | ax.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5) 21 | ax.set_ylabel('Close Price', fontsize=14) 22 | ax.legend(loc='lower left', fontsize=12) 23 | ax.grid(True) 24 | 25 | plt.tight_layout() 26 | plt.show() 27 | 28 | 29 | # 1. Load Model and Tokenizer 30 | tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base") 31 | model = Kronos.from_pretrained("NeoQuasar/Kronos-small") 32 | 33 | # 2. Instantiate Predictor 34 | predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512) 35 | 36 | # 3. Prepare Data 37 | df = pd.read_csv("./data/XSHG_5min_600977.csv") 38 | df['timestamps'] = pd.to_datetime(df['timestamps']) 39 | 40 | lookback = 400 41 | pred_len = 120 42 | 43 | x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close']] 44 | x_timestamp = df.loc[:lookback-1, 'timestamps'] 45 | y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps'] 46 | 47 | # 4. Make Prediction 48 | pred_df = predictor.predict( 49 | df=x_df, 50 | x_timestamp=x_timestamp, 51 | y_timestamp=y_timestamp, 52 | pred_len=pred_len, 53 | T=1.0, 54 | top_p=0.9, 55 | sample_count=1, 56 | verbose=True 57 | ) 58 | 59 | # 5. 
Visualize Results 60 | print("Forecasted Data Head:") 61 | print(pred_df.head()) 62 | 63 | # Combine historical and forecasted data for plotting 64 | kline_df = df.loc[:lookback+pred_len-1] 65 | 66 | # visualize 67 | plot_prediction(kline_df, pred_df) 68 | 69 | -------------------------------------------------------------------------------- /examples/cpu_prediction_wo_vol_examples.py: -------------------------------------------------------------------------------- 1 | # 官方案例,可在cpu环境下运行,需移到官方examples目录下运行 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | import sys 5 | sys.path.append("../") 6 | from model import Kronos, KronosTokenizer, KronosPredictor 7 | import torch 8 | 9 | 10 | def plot_prediction(kline_df, pred_df): 11 | pred_df.index = kline_df.index[-pred_df.shape[0]:] 12 | sr_close = kline_df['close'] 13 | sr_pred_close = pred_df['close'] 14 | sr_close.name = 'Ground Truth' 15 | sr_pred_close.name = "Prediction" 16 | 17 | close_df = pd.concat([sr_close, sr_pred_close], axis=1) 18 | 19 | fig, ax = plt.subplots(1, 1, figsize=(8, 4)) 20 | 21 | ax.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) 22 | ax.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5) 23 | ax.set_ylabel('Close Price', fontsize=14) 24 | ax.legend(loc='lower left', fontsize=12) 25 | ax.grid(True) 26 | 27 | plt.tight_layout() 28 | plt.show() 29 | 30 | 31 | # 1. Load Model and Tokenizer 32 | tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base") 33 | model = Kronos.from_pretrained("NeoQuasar/Kronos-small") 34 | 35 | # 2. Instantiate Predictor 36 | # predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512) 37 | if torch.cuda.is_available(): 38 | device = "cuda:0" 39 | elif torch.backends.mps.is_available(): 40 | device = "mps" 41 | else: 42 | device = "cpu" 43 | print(f"Using device: {device}") 44 | predictor = KronosPredictor(model, tokenizer, device=device, max_context=512) 45 | 46 | # 3. Prepare Data 47 | df = pd.read_csv("./data/XSHG_5min_600977.csv") 48 | df['timestamps'] = pd.to_datetime(df['timestamps']) 49 | 50 | lookback = 400 51 | pred_len = 120 52 | 53 | x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close']] 54 | x_timestamp = df.loc[:lookback-1, 'timestamps'] 55 | y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps'] 56 | 57 | # 4. Make Prediction 58 | pred_df = predictor.predict( 59 | df=x_df, 60 | x_timestamp=x_timestamp, 61 | y_timestamp=y_timestamp, 62 | pred_len=pred_len, 63 | T=1.0, 64 | top_p=0.9, 65 | sample_count=1, 66 | verbose=True 67 | ) 68 | 69 | # 5. 
Visualize Results 70 | print("Forecasted Data Head:") 71 | print(pred_df.head()) 72 | 73 | # Combine historical and forecasted data for plotting 74 | kline_df = df.loc[:lookback+pred_len-1] 75 | 76 | # visualize 77 | plot_prediction(kline_df, pred_df) 78 | 79 | -------------------------------------------------------------------------------- /examples/prediction_example.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | import sys 4 | sys.path.append("../") 5 | from model import Kronos, KronosTokenizer, KronosPredictor 6 | 7 | 8 | def plot_prediction(kline_df, pred_df): 9 | pred_df.index = kline_df.index[-pred_df.shape[0]:] 10 | sr_close = kline_df['close'] 11 | sr_pred_close = pred_df['close'] 12 | sr_close.name = 'Ground Truth' 13 | sr_pred_close.name = "Prediction" 14 | 15 | sr_volume = kline_df['volume'] 16 | sr_pred_volume = pred_df['volume'] 17 | sr_volume.name = 'Ground Truth' 18 | sr_pred_volume.name = "Prediction" 19 | 20 | close_df = pd.concat([sr_close, sr_pred_close], axis=1) 21 | volume_df = pd.concat([sr_volume, sr_pred_volume], axis=1) 22 | 23 | fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True) 24 | 25 | ax1.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) 26 | ax1.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5) 27 | ax1.set_ylabel('Close Price', fontsize=14) 28 | ax1.legend(loc='lower left', fontsize=12) 29 | ax1.grid(True) 30 | 31 | ax2.plot(volume_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) 32 | ax2.plot(volume_df['Prediction'], label='Prediction', color='red', linewidth=1.5) 33 | ax2.set_ylabel('Volume', fontsize=14) 34 | ax2.legend(loc='upper left', fontsize=12) 35 | ax2.grid(True) 36 | 37 | plt.tight_layout() 38 | plt.show() 39 | 40 | 41 | # 1. Load Model and Tokenizer 42 | tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base") 43 | model = Kronos.from_pretrained("NeoQuasar/Kronos-small") 44 | 45 | # 2. Instantiate Predictor 46 | predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512) 47 | 48 | # 3. Prepare Data 49 | df = pd.read_csv("./day/SH600006.txt") 50 | df['timestamps'] = pd.to_datetime(df['timestamps']) 51 | 52 | lookback = 400 53 | pred_len = 120 54 | 55 | x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close', 'volume', 'amount']] 56 | x_timestamp = df.loc[:lookback-1, 'timestamps'] 57 | y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps'] 58 | 59 | # 4. Make Prediction 60 | pred_df = predictor.predict( 61 | df=x_df, 62 | x_timestamp=x_timestamp, 63 | y_timestamp=y_timestamp, 64 | pred_len=pred_len, 65 | T=1.0, 66 | top_p=0.9, 67 | sample_count=1, 68 | verbose=True 69 | ) 70 | 71 | # 5. 
Visualize Results 72 | print("Forecasted Data Head:") 73 | print(pred_df.head()) 74 | 75 | # Combine historical and forecasted data for plotting 76 | kline_df = df.loc[:lookback+pred_len-1] 77 | 78 | # visualize 79 | plot_prediction(kline_df, pred_df) 80 | 81 | -------------------------------------------------------------------------------- /examples/cpu_prediction_example.py: -------------------------------------------------------------------------------- 1 | # 官方案例,可在cpu环境下运行,需移到官方examples目录下运行 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | import sys 5 | sys.path.append("../") 6 | from model import Kronos, KronosTokenizer, KronosPredictor 7 | import torch 8 | 9 | 10 | def plot_prediction(kline_df, pred_df): 11 | pred_df.index = kline_df.index[-pred_df.shape[0]:] 12 | sr_close = kline_df['close'] 13 | sr_pred_close = pred_df['close'] 14 | sr_close.name = 'Ground Truth' 15 | sr_pred_close.name = "Prediction" 16 | 17 | sr_volume = kline_df['volume'] 18 | sr_pred_volume = pred_df['volume'] 19 | sr_volume.name = 'Ground Truth' 20 | sr_pred_volume.name = "Prediction" 21 | 22 | close_df = pd.concat([sr_close, sr_pred_close], axis=1) 23 | volume_df = pd.concat([sr_volume, sr_pred_volume], axis=1) 24 | 25 | fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True) 26 | 27 | ax1.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) 28 | ax1.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5) 29 | ax1.set_ylabel('Close Price', fontsize=14) 30 | ax1.legend(loc='lower left', fontsize=12) 31 | ax1.grid(True) 32 | 33 | ax2.plot(volume_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) 34 | ax2.plot(volume_df['Prediction'], label='Prediction', color='red', linewidth=1.5) 35 | ax2.set_ylabel('Volume', fontsize=14) 36 | ax2.legend(loc='upper left', fontsize=12) 37 | ax2.grid(True) 38 | 39 | plt.tight_layout() 40 | plt.show() 41 | 42 | 43 | # 1. Load Model and Tokenizer 44 | tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base") 45 | model = Kronos.from_pretrained("NeoQuasar/Kronos-small") 46 | 47 | # 2. Instantiate Predictor 48 | # predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512) 49 | if torch.cuda.is_available(): 50 | device = "cuda:0" 51 | elif torch.backends.mps.is_available(): 52 | device = "mps" 53 | else: 54 | device = "cpu" 55 | print(f"Using device: {device}") 56 | predictor = KronosPredictor(model, tokenizer, device=device, max_context=512) 57 | 58 | # 3. Prepare Data 59 | df = pd.read_csv("./data/XSHG_5min_600977.csv") 60 | df['timestamps'] = pd.to_datetime(df['timestamps']) 61 | 62 | lookback = 400 63 | pred_len = 120 64 | 65 | x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close', 'volume', 'amount']] 66 | x_timestamp = df.loc[:lookback-1, 'timestamps'] 67 | y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps'] 68 | 69 | # 4. Make Prediction 70 | pred_df = predictor.predict( 71 | df=x_df, 72 | x_timestamp=x_timestamp, 73 | y_timestamp=y_timestamp, 74 | pred_len=pred_len, 75 | T=1.0, 76 | top_p=0.9, 77 | sample_count=1, 78 | verbose=True 79 | ) 80 | 81 | # 5. 
Visualize Results 82 | print("Forecasted Data Head:") 83 | print(pred_df.head()) 84 | 85 | # Combine historical and forecasted data for plotting 86 | kline_df = df.loc[:lookback+pred_len-1] 87 | 88 | # visualize 89 | plot_prediction(kline_df, pred_df) 90 | 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FaceCat-Kronos 金融量化预测工具 2 | 3 | ## 项目简介 4 | 5 | **FaceCat-Kronos** 是由 **花卷猫量化研究团队** 打造的一款金融量化工具。本项目基于清华大学最新开源的K线预测框架 **Kronos**,融合了前沿的人工智能技术,旨在为金融市场提供科学的分析与预测能力。 6 | 7 | 本工具能够对股票历史数据进行深度预训练,实现精准的做市商K线规划,并对未来市场走势进行科学推演,适用于量化研究、策略研发、交易决策支持、投研汇报、教学演示、二次开发。无论是基金、私募、荐股机构、专业投资者,或是量化研究员、业余爱好者,FaceCat-Kronos 都能为您提供k线做图能力,成为洞察市场的助手。 8 | 9 | ## 使用技术 10 | 11 | * **花卷猫Python**: 由花卷猫量化研究团队自研的图形框架,本项目UI(k线、图表、指标)全部由该框架实现。 12 | * **PySide**: 可在MacOs上运行。 13 | * **Kronos**: 清华K线预测模型。 14 | 15 | ## 数据源 16 | 17 | 股票数据全部来自 **花卷猫量化研究团队** 自研的行情服务器,项目地址 [stockquote行情服务端](https://jjmfc.com/software/stockquote.zip) ,可下载后部署在自己的服务器。 18 | 19 | ## 核心功能 20 | 21 | * **未来K线预测**: 利用Kronos框架,对未来股价走势、成交量等进行多维度预测,生成虚拟K线。 22 | * **历史数据回测**: 在历史数据上验证预测模型的准确性,直观比较预测结果与真实走势。 23 | * **多周期K线分析**: 提供分时、日线、周线、月线等多周期K线面板,满足不同策略的分析需求。 24 | * **交互式图表**: 简洁直观的图形化界面,支持点击切换股票、调整预测参数等便捷操作。 25 | * **灵活参数调整**: 支持调整 `Temperature` (温度) 和 `top_p` (核心采样) 等参数,控制预测模型的行为。 26 | 27 | ## 软件界面 28 | 29 | ### 主界面 30 | ![主界面](facecat/image/主界面.png) 31 | 32 | ### 代码表 33 | ![代码表](facecat/image/代码表.png) 34 | 35 | ### 分时图 36 | ![分时图](facecat/image/分时图.png) 37 | 38 | ### 五档行情 39 | ![五档行情](facecat/image/五档行情.png) 40 | 41 | ### 预测界面 42 | ![预测界面](facecat/image/预测界面.png) 43 | 44 | ### 预测模式 45 | ![预测模式](facecat/image/预测模式.png) 46 | 47 | ### 回测模式 48 | ![回测模式](facecat/image/回测模式.png) 49 | 50 | ### 预测k线 51 | ![预测k线](facecat/image/预测k线.png) 52 | 53 | ### 多K线面板 54 | ![多k线](facecat/image/多k线.png) 55 | 56 | ### 训练参数 57 | ![训练参数](facecat/image/训练参数.png) 58 | 59 | ## 我们的团队 60 | 61 | 花卷猫量化研究团队的成员均来自国内外金融机构及科技公司的量化部门: 62 | 63 | * 大智慧(龙软) 64 | * 东方财富 65 | * 东吴证券 66 | * 广发证券 67 | * 东海证券 68 | * 山西证券 69 | * 湘财证券 70 | * 华泰证券 71 | * 恒泰期货 72 | * 德意志银行 73 | 74 | 期间参与研发的系统或模块被大部分的证券期货公司、公募基金、私募或专业投资者使用,我们致力于将前沿的AI技术和市场理解融合,为专业投资者提供决策支持工具。 75 | 76 | ## 快速开始 77 | 78 | 在开始之前,请先阅读以下几点,这能帮助您避免很多不必要的麻烦: 79 | 80 | 1. **Python 版本**: 建议使用 Python 3.10+,3.9版本在windows下可能无法调用模型文件。 81 | 2. **运行平台**: 原始代码默认使用 NVIDIA CUDA GPU (`device="cuda:0"`)。如果您是 Mac 或没有 NVIDIA 显卡的 Windows/Linux 用户,直接运行会报错。教程附有修改后的代码,使CPU也能顺畅运行。 82 | 3. **依赖安装**: `requirements.txt` 中可能遗漏了间接依赖 `safetensors`,我们已在安装步骤中补充。为加速下载,建议使用国内镜像源。 83 | 4. **运行路径**: 请务必在正确的目录下执行命令,避免出现路径错误。 84 | 85 | ### 部署与运行 86 | 87 | 1. **下载项目** 88 | * 访问 [facecat-kronos GitHub 仓库](https://github.com/Fidingks/facecat-kronos) 下载ZIP包,或使用 Git 克隆。 89 | * 解压后,使用 VSCode 或其他IDE打开项目文件夹。 90 | 91 | 2. **安装依赖包** 92 | ```bash 93 | # 使用清华镜像源安装 94 | pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple 95 | 96 | # 安装可能遗漏的包 97 | pip install safetensors -i https://pypi.tuna.tsinghua.edu.cn/simple 98 | ``` 99 | 100 | 3. **运行官方示例 (可选)** 101 | * 如果您的设备配有NVIDIA GPU,可以先运行官方示例以验证环境。 102 | ```bash 103 | cd examples 104 | python prediction_example.py 105 | ``` 106 | * 若电脑无兼容的GPU,请使用我们提供的已修改好的CPU版本示例 (`cpu_prediction_example.py`, `cpu_prediction_wo_vol_examples.py`)。 107 | * 可能会出现无法加载在线模型,请到https://huggingface.co/NeoQuasar/Kronos-small/tree/main 下载模型放到model目录下 108 | * 运行成功后,可以与在 `figures` 目录下的 `prediction_example.png` 图片进行对比。 109 | 110 | 4. 
**运行 FaceCat-Kronos** 111 | * 首先确保您位于项目的根目录。 112 | ```bash 113 | # 切换到 facecat 目录 114 | cd facecat 115 | 116 | # 运行主程序 117 | python main.py 118 | ``` 119 | 120 | ### 使用说明 121 | 122 | * **预测/回测**: 启动程序后,点击右侧的 **“预测”** 按钮,即可使用历史数据预测未来走势。分割线右侧的虚K线即为预测结果。您也可以在下拉菜单中切换至 **“回测模式”**,进行历史数据比对。 123 | * **切换股票**: 在主界面左侧的表格中点击任意股票,即可加载并分析其数据。 124 | * **参数调整**: 125 | * `T` (Temperature/温度): 范围0-100,数值越大,预测结果越大胆、越发散。 126 | * `topP` (Top-p/概率阈值): 范围0-1,数值越大,模型选择的候选项越集中于高概率范围,结果更趋于合理。 127 | * **界面导航**: 128 | * **预测界面**: 使用日k线进行历史回测或者未来预测 129 | * **主界面**: 包含核心的股票列表、分时图、五档行情面板和多周期K线图。 130 | * **多K线**: 独立的多周期K线分析面板。 131 | 132 | -------------------------------------------------------------------------------- /finetune/utils/training_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import datetime 4 | import numpy as np 5 | import torch 6 | import torch.distributed as dist 7 | 8 | 9 | def setup_ddp(): 10 | """ 11 | Initializes the distributed data parallel environment. 12 | 13 | This function relies on environment variables set by `torchrun` or a similar 14 | launcher. It initializes the process group and sets the CUDA device for the 15 | current process. 16 | 17 | Returns: 18 | tuple: A tuple containing (rank, world_size, local_rank). 19 | """ 20 | if not dist.is_available(): 21 | raise RuntimeError("torch.distributed is not available.") 22 | 23 | dist.init_process_group(backend="nccl") 24 | rank = int(os.environ["RANK"]) 25 | world_size = int(os.environ["WORLD_SIZE"]) 26 | local_rank = int(os.environ["LOCAL_RANK"]) 27 | torch.cuda.set_device(local_rank) 28 | print( 29 | f"[DDP Setup] Global Rank: {rank}/{world_size}, " 30 | f"Local Rank (GPU): {local_rank} on device {torch.cuda.current_device()}" 31 | ) 32 | return rank, world_size, local_rank 33 | 34 | 35 | def cleanup_ddp(): 36 | """Cleans up the distributed process group.""" 37 | if dist.is_initialized(): 38 | dist.destroy_process_group() 39 | 40 | 41 | def set_seed(seed: int, rank: int = 0): 42 | """ 43 | Sets the random seed for reproducibility across all relevant libraries. 44 | 45 | Args: 46 | seed (int): The base seed value. 47 | rank (int): The process rank, used to ensure different processes have 48 | different seeds, which can be important for data loading. 49 | """ 50 | actual_seed = seed + rank 51 | random.seed(actual_seed) 52 | np.random.seed(actual_seed) 53 | torch.manual_seed(actual_seed) 54 | if torch.cuda.is_available(): 55 | torch.cuda.manual_seed_all(actual_seed) 56 | # The two lines below can impact performance, so they are often 57 | # reserved for final experiments where reproducibility is critical. 58 | torch.backends.cudnn.deterministic = True 59 | torch.backends.cudnn.benchmark = False 60 | 61 | 62 | def get_model_size(model: torch.nn.Module) -> str: 63 | """ 64 | Calculates the number of trainable parameters in a PyTorch model and returns 65 | it as a human-readable string. 66 | 67 | Args: 68 | model (torch.nn.Module): The PyTorch model. 69 | 70 | Returns: 71 | str: A string representing the model size (e.g., "175.0B", "7.1M", "50.5K"). 
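    Example (illustrative):
        >>> get_model_size(torch.nn.Linear(10, 10))  # 100 weights + 10 biases = 110 params
        '0.1K'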
72 | """ 73 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 74 | 75 | if total_params >= 1e9: 76 | return f"{total_params / 1e9:.1f}B" # Billions 77 | elif total_params >= 1e6: 78 | return f"{total_params / 1e6:.1f}M" # Millions 79 | else: 80 | return f"{total_params / 1e3:.1f}K" # Thousands 81 | 82 | 83 | def reduce_tensor(tensor: torch.Tensor, world_size: int, op=dist.ReduceOp.SUM) -> torch.Tensor: 84 | """ 85 | Reduces a tensor's value across all processes in a distributed setup. 86 | 87 | Args: 88 | tensor (torch.Tensor): The tensor to be reduced. 89 | world_size (int): The total number of processes. 90 | op (dist.ReduceOp, optional): The reduction operation (SUM, AVG, etc.). 91 | Defaults to dist.ReduceOp.SUM. 92 | 93 | Returns: 94 | torch.Tensor: The reduced tensor, which will be identical on all processes. 95 | """ 96 | rt = tensor.clone() 97 | dist.all_reduce(rt, op=op) 98 | # Note: `dist.ReduceOp.AVG` is available in newer torch versions. 99 | # For compatibility, manual division is sometimes used after a SUM. 100 | if op == dist.ReduceOp.AVG: 101 | rt /= world_size 102 | return rt 103 | 104 | 105 | def format_time(seconds: float) -> str: 106 | """ 107 | Formats a duration in seconds into a human-readable H:M:S string. 108 | 109 | Args: 110 | seconds (float): The total seconds. 111 | 112 | Returns: 113 | str: The formatted time string (e.g., "0:15:32"). 114 | """ 115 | return str(datetime.timedelta(seconds=int(seconds))) 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /finetune/qlib_data_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import pandas as pd 5 | import qlib 6 | from qlib.config import REG_CN 7 | from qlib.data import D 8 | from qlib.data.dataset.loader import QlibDataLoader 9 | from tqdm import trange 10 | 11 | from config import Config 12 | 13 | 14 | class QlibDataPreprocessor: 15 | """ 16 | A class to handle the loading, processing, and splitting of Qlib financial data. 17 | """ 18 | 19 | def __init__(self): 20 | """Initializes the preprocessor with configuration and data fields.""" 21 | self.config = Config() 22 | self.data_fields = ['open', 'close', 'high', 'low', 'volume', 'vwap'] 23 | self.data = {} # A dictionary to store processed data for each symbol. 24 | 25 | def initialize_qlib(self): 26 | """Initializes the Qlib environment.""" 27 | print("Initializing Qlib...") 28 | qlib.init(provider_uri=self.config.qlib_data_path, region=REG_CN) 29 | 30 | def load_qlib_data(self): 31 | """ 32 | Loads raw data from Qlib, processes it symbol by symbol, and stores 33 | it in the `self.data` attribute. 34 | """ 35 | print("Loading and processing data from Qlib...") 36 | data_fields_qlib = ['$' + f for f in self.data_fields] 37 | cal: np.ndarray = D.calendar() 38 | 39 | # Determine the actual start and end times to load, including buffer for lookback and predict windows. 40 | start_index = cal.searchsorted(pd.Timestamp(self.config.dataset_begin_time)) 41 | end_index = cal.searchsorted(pd.Timestamp(self.config.dataset_end_time)) 42 | real_start_time = cal[start_index - self.config.lookback_window] 43 | 44 | if cal[end_index] != pd.Timestamp(self.config.dataset_end_time): 45 | end_index -= 1 46 | real_end_time = cal[end_index + self.config.predict_window] 47 | 48 | # Load data using Qlib's data loader. 
49 | data_df = QlibDataLoader(config=data_fields_qlib).load( 50 | self.config.instrument, real_start_time, real_end_time 51 | ) 52 | data_df = data_df.stack().unstack(level=1) # Reshape for easier access. 53 | 54 | symbol_list = list(data_df.columns) 55 | for i in trange(len(symbol_list), desc="Processing Symbols"): 56 | symbol = symbol_list[i] 57 | symbol_df = data_df[symbol] 58 | 59 | # Pivot the table to have features as columns and datetime as index. 60 | symbol_df = symbol_df.reset_index().rename(columns={'level_1': 'field'}) 61 | symbol_df = pd.pivot(symbol_df, index='datetime', columns='field', values=symbol) 62 | symbol_df = symbol_df.rename(columns={f'${field}': field for field in self.data_fields}) 63 | 64 | # Calculate amount and select final features. 65 | symbol_df['vol'] = symbol_df['volume'] 66 | symbol_df['amt'] = (symbol_df['open'] + symbol_df['high'] + symbol_df['low'] + symbol_df['close']) / 4 * symbol_df['vol'] 67 | symbol_df = symbol_df[self.config.feature_list] 68 | 69 | # Filter out symbols with insufficient data. 70 | symbol_df = symbol_df.dropna() 71 | if len(symbol_df) < self.config.lookback_window + self.config.predict_window + 1: 72 | continue 73 | 74 | self.data[symbol] = symbol_df 75 | 76 | def prepare_dataset(self): 77 | """ 78 | Splits the loaded data into train, validation, and test sets and saves them to disk. 79 | """ 80 | print("Splitting data into train, validation, and test sets...") 81 | train_data, val_data, test_data = {}, {}, {} 82 | 83 | symbol_list = list(self.data.keys()) 84 | for i in trange(len(symbol_list), desc="Preparing Datasets"): 85 | symbol = symbol_list[i] 86 | symbol_df = self.data[symbol] 87 | 88 | # Define time ranges from config. 89 | train_start, train_end = self.config.train_time_range 90 | val_start, val_end = self.config.val_time_range 91 | test_start, test_end = self.config.test_time_range 92 | 93 | # Create boolean masks for each dataset split. 94 | train_mask = (symbol_df.index >= train_start) & (symbol_df.index <= train_end) 95 | val_mask = (symbol_df.index >= val_start) & (symbol_df.index <= val_end) 96 | test_mask = (symbol_df.index >= test_start) & (symbol_df.index <= test_end) 97 | 98 | # Apply masks to create the final datasets. 99 | train_data[symbol] = symbol_df[train_mask] 100 | val_data[symbol] = symbol_df[val_mask] 101 | test_data[symbol] = symbol_df[test_mask] 102 | 103 | # Save the datasets using pickle. 104 | os.makedirs(self.config.dataset_path, exist_ok=True) 105 | with open(f"{self.config.dataset_path}/train_data.pkl", 'wb') as f: 106 | pickle.dump(train_data, f) 107 | with open(f"{self.config.dataset_path}/val_data.pkl", 'wb') as f: 108 | pickle.dump(val_data, f) 109 | with open(f"{self.config.dataset_path}/test_data.pkl", 'wb') as f: 110 | pickle.dump(test_data, f) 111 | 112 | print("Datasets prepared and saved successfully.") 113 | 114 | 115 | if __name__ == '__main__': 116 | # This block allows the script to be run directly to perform data preprocessing. 117 | preprocessor = QlibDataPreprocessor() 118 | preprocessor.initialize_qlib() 119 | preprocessor.load_qlib_data() 120 | preprocessor.prepare_dataset() 121 | -------------------------------------------------------------------------------- /README.en.md: -------------------------------------------------------------------------------- 1 | # FaceCat-Kronos: Financial Quantitative Forecasting Tool 2 | 3 | ## About This Project 4 | 5 | **FaceCat-Kronos** is a financial quantitative tool developed by the **Huajuanmao Quantitative Research Team**. 
This project is built upon **Kronos**, the latest open-source K-line forecasting framework from Tsinghua University. It integrates advanced artificial intelligence technology, aiming to provide scientific analysis and prediction capabilities for the financial market. 6 | 7 | This tool can perform in-depth pre-training on historical stock data for market-maker K-line planning and deduce future market trends. Whether you are a mutual fund, a private equity firm, or a professional stock recommendation agency, FaceCat-Kronos can provide you with K-line charting capabilities and serve as an assistant for market insights. 8 | 9 | ## Core Features 10 | 11 | * **Future K-Line Prediction**: Utilizes the Kronos framework to predict multi-dimensional data such as future stock price trends and trading volumes, generating virtual K-lines. 12 | * **Historical Data Backtesting**: Validates the accuracy of the prediction model on historical data, allowing for an intuitive comparison between predicted and actual trends. 13 | * **Multi-Period K-Line Analysis**: Provides various K-line panels, including intraday, daily, weekly, and monthly charts, to meet the needs of different trading strategies. 14 | * **Interactive Charts**: A clean and intuitive graphical user interface that supports clicking to switch stocks, adjusting prediction parameters, and other convenient operations. 15 | * **Flexible Parameter Tuning**: Supports adjusting parameters like `Temperature` and `top_p` (nucleus sampling) to control the behavior of the prediction model. 16 | 17 | ## Software Interface 18 | 19 | ### Main Interface 20 | ![Main Interface](facecat/image/主界面.png) 21 | 22 | ### Prediction Interface 23 | ![Prediction Interface](facecat/image/预测界面.png) 24 | 25 | ### Prediction Mode 26 | ![Prediction Mode](facecat/image/预测模式.png) 27 | 28 | ### Backtesting Mode 29 | ![Backtesting Mode](facecat/image/回测模式.png) 30 | 31 | ### Multi-Period K-Line Panel 32 | ![Multi-K-Line](facecat/image/多k线.png) 33 | 34 | ## Our Team 35 | 36 | The members of the Huajuanmao Quantitative Research Team come from the quantitative departments of domestic and international financial institutions and technology companies: 37 | 38 | * Great Wisdom (Longtop) 39 | * East Money 40 | * Soochow Securities 41 | * GF Securities 42 | * Donghai Securities 43 | * Shanxi Securities 44 | * Xiangcai Securities 45 | * Huatai Securities 46 | * Hengtai Futures 47 | * Deutsche Bank 48 | 49 | The systems or modules developed by our members have been used by a majority of securities and futures companies, mutual funds, private equities, and professional investors. We are committed to integrating cutting-edge AI technology with market understanding to provide decision-support tools for professional investors. 50 | 51 | ## Quick Start 52 | 53 | Before you begin, please read the following points to avoid unnecessary issues: 54 | 55 | 1. **Python Version**: Python 3.10+ is officially recommended. 56 | 2. **Platform**: The original code defaults to using an NVIDIA CUDA GPU (`device="cuda:0"`). If you are a Mac user or a Windows/Linux user without an NVIDIA graphics card, running it directly will cause an error. The tutorial includes modified code to enable CPU execution. 57 | 3. **Dependency Installation**: `requirements.txt` might be missing the indirect dependency `safetensors`, which we have added to the installation steps. 58 | 4. **Running Path**: Be sure to execute commands in the correct directory to avoid path-related errors. 
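As a companion to note 2 above, the device fallback used by the bundled CPU examples comes down to a few lines. Here is a minimal sketch, mirroring `cpu_prediction_example.py`, that you can reuse in your own scripts:

```python
import sys
sys.path.append("../")  # run from the examples/ directory, as the bundled scripts do
import torch
from model import Kronos, KronosTokenizer, KronosPredictor

# Prefer NVIDIA CUDA, then Apple MPS, then fall back to plain CPU.
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
```

CPU inference is noticeably slower than GPU, but the prediction logic is unchanged.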
59 | 60 | ### Deployment and Operation 61 | 62 | 1. **Download the Project** 63 | * Visit the [FaceCat-Kronos GitHub repository](https://github.com/Fidingks/facecat-kronos) to download the ZIP package, or clone it using Git. 64 | * After unzipping, open the project folder with VSCode or another IDE. 65 | 66 | 2. **Install Dependencies** 67 | ```bash 68 | # Install using the requirements file 69 | pip install -r requirements.txt 70 | 71 | # Install the potentially missing package 72 | pip install safetensors 73 | ``` 74 | 75 | 3. **Run the Official Example (Optional)** 76 | * If your machine is equipped with an NVIDIA GPU, you can run the official example first to verify the environment. 77 | ```bash 78 | cd examples 79 | python prediction_example.py 80 | ``` 81 | * If your computer does not have a compatible GPU, please use the modified CPU version examples we provide (`cpu_prediction_example.py`, `cpu_prediction_wo_vol_examples.py`). 82 | * After a successful run, you can compare your output with the `prediction_example.png` image in the `figures` directory. 83 | 84 | 4. **Run FaceCat-Kronos** 85 | * First, ensure you are in the project's root directory. 86 | ```bash 87 | # Change to the facecat directory 88 | cd facecat 89 | 90 | # Run the main program 91 | python main.py 92 | ``` 93 | 94 | ### Usage Instructions 95 | 96 | * **Prediction/Backtesting**: After launching the program, click the **"Predict"** button on the right to use historical data to forecast future trends. The virtual K-lines to the right of the separator line are the prediction results. You can also switch to **"Backtesting Mode"** from the dropdown menu to compare with historical data. 97 | * **Switching Stocks**: Click on any stock in the table on the left of the main interface to load and analyze its data. 98 | * **Parameter Tuning**: 99 | * `T` (Temperature): Range 0-100. A higher value leads to bolder and more diverse predictions. 100 | * `topP` (Top-p/Nucleus Sampling): Range 0-1. A higher value makes the model's choices more concentrated on high-probability options, resulting in more plausible outcomes. 101 | * **Interface Navigation**: 102 | * **Prediction Interface**: Use daily K-lines for historical backtesting or future prediction. 103 | * **Main Interface**: Includes the core stock list, intraday chart, Level 2 order book panel, and multi-period K-line chart. 104 | * **Multi K-Line**: A separate panel for multi-period K-line analysis. 105 | -------------------------------------------------------------------------------- /finetune/dataset.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import random 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | from config import Config 7 | 8 | 9 | class QlibDataset(Dataset): 10 | """ 11 | A PyTorch Dataset for handling Qlib financial time series data. 12 | 13 | This dataset pre-computes all possible start indices for sliding windows 14 | and then randomly samples from them during training/validation. 15 | 16 | Args: 17 | data_type (str): The type of dataset to load, either 'train' or 'val'. 18 | 19 | Raises: 20 | ValueError: If `data_type` is not 'train' or 'val'. 
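    Example (illustrative):
        >>> ds = QlibDataset('train')  # expects the pickles written by qlib_data_preprocess.py
        >>> x, x_stamp = ds[0]         # the index is ignored; a random window is drawn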
21 | """ 22 | 23 | def __init__(self, data_type: str = 'train'): 24 | self.config = Config() 25 | if data_type not in ['train', 'val']: 26 | raise ValueError("data_type must be 'train' or 'val'") 27 | self.data_type = data_type 28 | 29 | # Use a dedicated random number generator for sampling to avoid 30 | # interfering with other random processes (e.g., in model initialization). 31 | self.py_rng = random.Random(self.config.seed) 32 | 33 | # Set paths and number of samples based on the data type. 34 | if data_type == 'train': 35 | self.data_path = f"{self.config.dataset_path}/train_data.pkl" 36 | self.n_samples = self.config.n_train_iter 37 | else: 38 | self.data_path = f"{self.config.dataset_path}/val_data.pkl" 39 | self.n_samples = self.config.n_val_iter 40 | 41 | with open(self.data_path, 'rb') as f: 42 | self.data = pickle.load(f) 43 | 44 | self.window = self.config.lookback_window + self.config.predict_window + 1 45 | 46 | self.symbols = list(self.data.keys()) 47 | self.feature_list = self.config.feature_list 48 | self.time_feature_list = self.config.time_feature_list 49 | 50 | # Pre-compute all possible (symbol, start_index) pairs. 51 | self.indices = [] 52 | print(f"[{data_type.upper()}] Pre-computing sample indices...") 53 | for symbol in self.symbols: 54 | df = self.data[symbol].reset_index() 55 | series_len = len(df) 56 | num_samples = series_len - self.window + 1 57 | 58 | if num_samples > 0: 59 | # Generate time features and store them directly in the dataframe. 60 | df['minute'] = df['datetime'].dt.minute 61 | df['hour'] = df['datetime'].dt.hour 62 | df['weekday'] = df['datetime'].dt.weekday 63 | df['day'] = df['datetime'].dt.day 64 | df['month'] = df['datetime'].dt.month 65 | # Keep only necessary columns to save memory. 66 | self.data[symbol] = df[self.feature_list + self.time_feature_list] 67 | 68 | # Add all valid starting indices for this symbol to the global list. 69 | for i in range(num_samples): 70 | self.indices.append((symbol, i)) 71 | 72 | # The effective dataset size is the minimum of the configured iterations 73 | # and the total number of available samples. 74 | self.n_samples = min(self.n_samples, len(self.indices)) 75 | print(f"[{data_type.upper()}] Found {len(self.indices)} possible samples. Using {self.n_samples} per epoch.") 76 | 77 | def set_epoch_seed(self, epoch: int): 78 | """ 79 | Sets a new seed for the random sampler for each epoch. This is crucial 80 | for reproducibility in distributed training. 81 | 82 | Args: 83 | epoch (int): The current epoch number. 84 | """ 85 | epoch_seed = self.config.seed + epoch 86 | self.py_rng.seed(epoch_seed) 87 | 88 | def __len__(self) -> int: 89 | """Returns the number of samples per epoch.""" 90 | return self.n_samples 91 | 92 | def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: 93 | """ 94 | Retrieves a random sample from the dataset. 95 | 96 | Note: The `idx` argument is ignored. Instead, a random index is drawn 97 | from the pre-computed `self.indices` list using `self.py_rng`. This 98 | ensures random sampling over the entire dataset for each call. 99 | 100 | Args: 101 | idx (int): Ignored. 102 | 103 | Returns: 104 | tuple[torch.Tensor, torch.Tensor]: A tuple containing: 105 | - x_tensor (torch.Tensor): The normalized feature tensor. 106 | - x_stamp_tensor (torch.Tensor): The time feature tensor. 107 | """ 108 | # Select a random sample from the entire pool of indices. 
109 | random_idx = self.py_rng.randint(0, len(self.indices) - 1) 110 | symbol, start_idx = self.indices[random_idx] 111 | 112 | # Extract the sliding window from the dataframe. 113 | df = self.data[symbol] 114 | end_idx = start_idx + self.window 115 | win_df = df.iloc[start_idx:end_idx] 116 | 117 | # Separate main features and time features. 118 | x = win_df[self.feature_list].values.astype(np.float32) 119 | x_stamp = win_df[self.time_feature_list].values.astype(np.float32) 120 | 121 | # Perform instance-level normalization. 122 | x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0) 123 | x = (x - x_mean) / (x_std + 1e-5) 124 | x = np.clip(x, -self.config.clip, self.config.clip) 125 | 126 | # Convert to PyTorch tensors. 127 | x_tensor = torch.from_numpy(x) 128 | x_stamp_tensor = torch.from_numpy(x_stamp) 129 | 130 | return x_tensor, x_stamp_tensor 131 | 132 | 133 | if __name__ == '__main__': 134 | # Example usage and verification. 135 | print("Creating training dataset instance...") 136 | train_dataset = QlibDataset(data_type='train') 137 | 138 | print(f"Dataset length: {len(train_dataset)}") 139 | 140 | if len(train_dataset) > 0: 141 | try_x, try_x_stamp = train_dataset[100] # Index 100 is ignored. 142 | print(f"Sample feature shape: {try_x.shape}") 143 | print(f"Sample time feature shape: {try_x_stamp.shape}") 144 | else: 145 | print("Dataset is empty.") 146 | -------------------------------------------------------------------------------- /finetune/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class Config: 4 | """ 5 | Configuration class for the entire project. 6 | """ 7 | 8 | def __init__(self): 9 | # ================================================================= 10 | # Data & Feature Parameters 11 | # ================================================================= 12 | # TODO: Update this path to your Qlib data directory. 13 | self.qlib_data_path = "~/.qlib/qlib_data/cn_data" 14 | self.instrument = 'csi300' 15 | 16 | # Overall time range for data loading from Qlib. 17 | self.dataset_begin_time = "2011-01-01" 18 | self.dataset_end_time = '2025-06-05' 19 | 20 | # Sliding window parameters for creating samples. 21 | self.lookback_window = 90 # Number of past time steps for input. 22 | self.predict_window = 10 # Number of future time steps for prediction. 23 | self.max_context = 512 # Maximum context length for the model. 24 | 25 | # Features to be used from the raw data. 26 | self.feature_list = ['open', 'high', 'low', 'close', 'vol', 'amt'] 27 | # Time-based features to be generated. 28 | self.time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month'] 29 | 30 | # ================================================================= 31 | # Dataset Splitting & Paths 32 | # ================================================================= 33 | # Note: The validation/test set starts earlier than the training/validation set ends 34 | # to account for the `lookback_window`. 35 | self.train_time_range = ["2011-01-01", "2022-12-31"] 36 | self.val_time_range = ["2022-09-01", "2024-06-30"] 37 | self.test_time_range = ["2024-04-01", "2025-06-05"] 38 | self.backtest_time_range = ["2024-07-01", "2025-06-05"] 39 | 40 | # TODO: Directory to save the processed, pickled datasets. 
41 | self.dataset_path = "./data/processed_datasets" 42 | 43 | # ================================================================= 44 | # Training Hyperparameters 45 | # ================================================================= 46 | self.clip = 5.0 # Clipping value for normalized data to prevent outliers. 47 | 48 | self.epochs = 30 49 | self.log_interval = 100 # Log training status every N batches. 50 | self.batch_size = 50 # Batch size per GPU. 51 | 52 | # Number of samples to draw for one "epoch" of training/validation. 53 | # This is useful for large datasets where a true epoch is too long. 54 | self.n_train_iter = 2000 * self.batch_size 55 | self.n_val_iter = 400 * self.batch_size 56 | 57 | # Learning rates for different model components. 58 | self.tokenizer_learning_rate = 2e-4 59 | self.predictor_learning_rate = 4e-5 60 | 61 | # Gradient accumulation to simulate a larger batch size. 62 | self.accumulation_steps = 1 63 | 64 | # AdamW optimizer parameters. 65 | self.adam_beta1 = 0.9 66 | self.adam_beta2 = 0.95 67 | self.adam_weight_decay = 0.1 68 | 69 | # Miscellaneous 70 | self.seed = 100 # Global random seed for reproducibility. 71 | 72 | # ================================================================= 73 | # Experiment Logging & Saving 74 | # ================================================================= 75 | self.use_comet = True # Set to False if you don't want to use Comet ML 76 | self.comet_config = { 77 | # It is highly recommended to load secrets from environment variables 78 | # for security purposes. Example: os.getenv("COMET_API_KEY") 79 | "api_key": "YOUR_COMET_API_KEY", 80 | "project_name": "Kronos-Finetune-Demo", 81 | "workspace": "your_comet_workspace" # TODO: Change to your Comet ML workspace name 82 | } 83 | self.comet_tag = 'finetune_demo' 84 | self.comet_name = 'finetune_demo' 85 | 86 | # Base directory for saving model checkpoints and results. 87 | # Using a general 'outputs' directory is a common practice. 88 | self.save_path = "./outputs/models" 89 | self.tokenizer_save_folder_name = 'finetune_tokenizer_demo' 90 | self.predictor_save_folder_name = 'finetune_predictor_demo' 91 | self.backtest_save_folder_name = 'finetune_backtest_demo' 92 | 93 | # Path for backtesting results. 94 | self.backtest_result_path = "./outputs/backtest_results" 95 | 96 | # ================================================================= 97 | # Model & Checkpoint Paths 98 | # ================================================================= 99 | # TODO: Update these paths to your pretrained model locations. 100 | # These can be local paths or Hugging Face Hub model identifiers. 101 | self.pretrained_tokenizer_path = "path/to/your/Kronos-Tokenizer-base" 102 | self.pretrained_predictor_path = "path/to/your/Kronos-small" 103 | 104 | # Paths to the fine-tuned models, derived from the save_path. 105 | # These will be generated automatically during training. 106 | self.finetuned_tokenizer_path = f"{self.save_path}/{self.tokenizer_save_folder_name}/checkpoints/best_model" 107 | self.finetuned_predictor_path = f"{self.save_path}/{self.predictor_save_folder_name}/checkpoints/best_model" 108 | 109 | # ================================================================= 110 | # Backtesting Parameters 111 | # ================================================================= 112 | self.backtest_n_symbol_hold = 50 # Number of symbols to hold in the portfolio. 113 | self.backtest_n_symbol_drop = 5 # Number of symbols to drop from the pool. 
114 | self.backtest_hold_thresh = 5 # Minimum holding period for a stock. 115 | self.inference_T = 0.6 116 | self.inference_top_p = 0.9 117 | self.inference_top_k = 0 118 | self.inference_sample_count = 5 119 | self.backtest_batch_size = 1000 120 | self.backtest_benchmark = self._set_benchmark(self.instrument) 121 | 122 | def _set_benchmark(self, instrument): 123 | dt_benchmark = { 124 | 'csi800': "SH000906", 125 | 'csi1000': "SH000852", 126 | 'csi300': "SH000300", 127 | } 128 | if instrument in dt_benchmark: 129 | return dt_benchmark[instrument] 130 | else: 131 | raise ValueError(f"Benchmark not defined for instrument: {instrument}") 132 | -------------------------------------------------------------------------------- /finetune/train_predictor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import time 5 | from time import gmtime, strftime 6 | import torch.distributed as dist 7 | import torch 8 | from torch.utils.data import DataLoader 9 | from torch.utils.data.distributed import DistributedSampler 10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | 12 | import comet_ml 13 | 14 | # Ensure project root is in path 15 | sys.path.append('../') 16 | from config import Config 17 | from dataset import QlibDataset 18 | from model.kronos import KronosTokenizer, Kronos 19 | # Import shared utilities 20 | from utils.training_utils import ( 21 | setup_ddp, 22 | cleanup_ddp, 23 | set_seed, 24 | get_model_size, 25 | format_time 26 | ) 27 | 28 | 29 | def create_dataloaders(config: dict, rank: int, world_size: int): 30 | """ 31 | Creates and returns distributed dataloaders for training and validation. 32 | 33 | Args: 34 | config (dict): A dictionary of configuration parameters. 35 | rank (int): The global rank of the current process. 36 | world_size (int): The total number of processes. 37 | 38 | Returns: 39 | tuple: (train_loader, val_loader, train_dataset, valid_dataset). 40 | """ 41 | print(f"[Rank {rank}] Creating distributed dataloaders...") 42 | train_dataset = QlibDataset('train') 43 | valid_dataset = QlibDataset('val') 44 | print(f"[Rank {rank}] Train dataset size: {len(train_dataset)}, Validation dataset size: {len(valid_dataset)}") 45 | 46 | train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True) 47 | val_sampler = DistributedSampler(valid_dataset, num_replicas=world_size, rank=rank, shuffle=False) 48 | 49 | train_loader = DataLoader( 50 | train_dataset, batch_size=config['batch_size'], sampler=train_sampler, 51 | num_workers=config.get('num_workers', 2), pin_memory=True, drop_last=True 52 | ) 53 | val_loader = DataLoader( 54 | valid_dataset, batch_size=config['batch_size'], sampler=val_sampler, 55 | num_workers=config.get('num_workers', 2), pin_memory=True, drop_last=False 56 | ) 57 | return train_loader, val_loader, train_dataset, valid_dataset 58 | 59 | 60 | def train_model(model, tokenizer, device, config, save_dir, logger, rank, world_size): 61 | """ 62 | The main training and validation loop for the predictor. 
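    Arguments mirror `train_tokenizer.train_model`, with the addition of
    `tokenizer` (KronosTokenizer): the frozen tokenizer used to encode raw
    K-line windows into token sequences on the fly.

    Returns:
        dict: A dictionary holding the best validation loss.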
63 | """ 64 | start_time = time.time() 65 | if rank == 0: 66 | effective_bs = config['batch_size'] * world_size 67 | print(f"Effective BATCHSIZE per GPU: {config['batch_size']}, Total: {effective_bs}") 68 | 69 | train_loader, val_loader, train_dataset, valid_dataset = create_dataloaders(config, rank, world_size) 70 | 71 | optimizer = torch.optim.AdamW( 72 | model.parameters(), 73 | lr=config['predictor_learning_rate'], 74 | betas=(config['adam_beta1'], config['adam_beta2']), 75 | weight_decay=config['adam_weight_decay'] 76 | ) 77 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 78 | optimizer, max_lr=config['predictor_learning_rate'], 79 | steps_per_epoch=len(train_loader), epochs=config['epochs'], 80 | pct_start=0.03, div_factor=10 81 | ) 82 | 83 | best_val_loss = float('inf') 84 | dt_result = {} 85 | batch_idx_global = 0 86 | 87 | for epoch_idx in range(config['epochs']): 88 | epoch_start_time = time.time() 89 | model.train() 90 | train_loader.sampler.set_epoch(epoch_idx) 91 | 92 | train_dataset.set_epoch_seed(epoch_idx * 10000 + rank) 93 | valid_dataset.set_epoch_seed(0) 94 | 95 | for i, (batch_x, batch_x_stamp) in enumerate(train_loader): 96 | batch_x = batch_x.squeeze(0).to(device, non_blocking=True) 97 | batch_x_stamp = batch_x_stamp.squeeze(0).to(device, non_blocking=True) 98 | 99 | # Tokenize input data on-the-fly 100 | with torch.no_grad(): 101 | token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True) 102 | 103 | # Prepare inputs and targets for the language model 104 | token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]] 105 | token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]] 106 | 107 | # Forward pass and loss calculation 108 | logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :]) 109 | loss, s1_loss, s2_loss = model.module.head.compute_loss(logits[0], logits[1], token_out[0], token_out[1]) 110 | 111 | # Backward pass and optimization 112 | optimizer.zero_grad() 113 | loss.backward() 114 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0) 115 | optimizer.step() 116 | scheduler.step() 117 | 118 | # Logging (Master Process Only) 119 | if rank == 0 and (batch_idx_global + 1) % config['log_interval'] == 0: 120 | lr = optimizer.param_groups[0]['lr'] 121 | print( 122 | f"[Rank {rank}, Epoch {epoch_idx + 1}/{config['epochs']}, Step {i + 1}/{len(train_loader)}] " 123 | f"LR {lr:.6f}, Loss: {loss.item():.4f}" 124 | ) 125 | if rank == 0 and logger: 126 | lr = optimizer.param_groups[0]['lr'] 127 | logger.log_metric('train_predictor_loss_batch', loss.item(), step=batch_idx_global) 128 | logger.log_metric('train_S1_loss_each_batch', s1_loss.item(), step=batch_idx_global) 129 | logger.log_metric('train_S2_loss_each_batch', s2_loss.item(), step=batch_idx_global) 130 | logger.log_metric('predictor_learning_rate', lr, step=batch_idx_global) 131 | 132 | batch_idx_global += 1 133 | 134 | # --- Validation Loop --- 135 | model.eval() 136 | tot_val_loss_sum_rank = 0.0 137 | val_batches_processed_rank = 0 138 | with torch.no_grad(): 139 | for batch_x, batch_x_stamp in val_loader: 140 | batch_x = batch_x.squeeze(0).to(device, non_blocking=True) 141 | batch_x_stamp = batch_x_stamp.squeeze(0).to(device, non_blocking=True) 142 | 143 | token_seq_0, token_seq_1 = tokenizer.encode(batch_x, half=True) 144 | token_in = [token_seq_0[:, :-1], token_seq_1[:, :-1]] 145 | token_out = [token_seq_0[:, 1:], token_seq_1[:, 1:]] 146 | 147 | logits = model(token_in[0], token_in[1], batch_x_stamp[:, :-1, :]) 148 | val_loss, _, _ = model.module.head.compute_loss(logits[0], 
logits[1], token_out[0], token_out[1]) 149 | 150 | tot_val_loss_sum_rank += val_loss.item() 151 | val_batches_processed_rank += 1 152 | 153 | # Reduce validation metrics 154 | val_loss_sum_tensor = torch.tensor(tot_val_loss_sum_rank, device=device) 155 | val_batches_tensor = torch.tensor(val_batches_processed_rank, device=device) 156 | dist.all_reduce(val_loss_sum_tensor, op=dist.ReduceOp.SUM) 157 | dist.all_reduce(val_batches_tensor, op=dist.ReduceOp.SUM) 158 | 159 | avg_val_loss = val_loss_sum_tensor.item() / val_batches_tensor.item() if val_batches_tensor.item() > 0 else 0 160 | 161 | # --- End of Epoch Summary & Checkpointing (Master Process Only) --- 162 | if rank == 0: 163 | print(f"\n--- Epoch {epoch_idx + 1}/{config['epochs']} Summary ---") 164 | print(f"Validation Loss: {avg_val_loss:.4f}") 165 | print(f"Time This Epoch: {format_time(time.time() - epoch_start_time)}") 166 | print(f"Total Time Elapsed: {format_time(time.time() - start_time)}\n") 167 | if logger: 168 | logger.log_metric('val_predictor_loss_epoch', avg_val_loss, epoch=epoch_idx) 169 | 170 | if avg_val_loss < best_val_loss: 171 | best_val_loss = avg_val_loss 172 | save_path = f"{save_dir}/checkpoints/best_model" 173 | model.module.save_pretrained(save_path) 174 | print(f"Best model saved to {save_path} (Val Loss: {best_val_loss:.4f})") 175 | 176 | dist.barrier() 177 | 178 | dt_result['best_val_loss'] = best_val_loss 179 | return dt_result 180 | 181 | 182 | def main(config: dict): 183 | """Main function to orchestrate the DDP training process.""" 184 | rank, world_size, local_rank = setup_ddp() 185 | device = torch.device(f"cuda:{local_rank}") 186 | set_seed(config['seed'], rank) 187 | 188 | save_dir = os.path.join(config['save_path'], config['predictor_save_folder_name']) 189 | 190 | # Logger and summary setup (master process only) 191 | comet_logger, master_summary = None, {} 192 | if rank == 0: 193 | os.makedirs(os.path.join(save_dir, 'checkpoints'), exist_ok=True) 194 | master_summary = { 195 | 'start_time': strftime("%Y-%m-%dT%H-%M-%S", gmtime()), 196 | 'save_directory': save_dir, 197 | 'world_size': world_size, 198 | } 199 | if config['use_comet']: 200 | comet_logger = comet_ml.Experiment( 201 | api_key=config['comet_config']['api_key'], 202 | project_name=config['comet_config']['project_name'], 203 | workspace=config['comet_config']['workspace'], 204 | ) 205 | comet_logger.add_tag(config['comet_tag']) 206 | comet_logger.set_name(config['comet_name']) 207 | comet_logger.log_parameters(config) 208 | print("Comet Logger Initialized.") 209 | 210 | dist.barrier() 211 | 212 | # Model Initialization 213 | tokenizer = KronosTokenizer.from_pretrained(config['finetuned_tokenizer_path']) 214 | tokenizer.eval().to(device) 215 | 216 | model = Kronos.from_pretrained(config['pretrained_predictor_path']) 217 | model.to(device) 218 | model = DDP(model, device_ids=[local_rank], find_unused_parameters=False) 219 | 220 | if rank == 0: 221 | print(f"Predictor Model Size: {get_model_size(model.module)}") 222 | 223 | # Start Training 224 | dt_result = train_model( 225 | model, tokenizer, device, config, save_dir, comet_logger, rank, world_size 226 | ) 227 | 228 | if rank == 0: 229 | master_summary['final_result'] = dt_result 230 | with open(os.path.join(save_dir, 'summary.json'), 'w') as f: 231 | json.dump(master_summary, f, indent=4) 232 | print('Training finished. 
Summary file saved.') 233 | if comet_logger: comet_logger.end() 234 | 235 | cleanup_ddp() 236 | 237 | 238 | if __name__ == '__main__': 239 | # Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_predictor.py 240 | if "WORLD_SIZE" not in os.environ: 241 | raise RuntimeError("This script must be launched with `torchrun`.") 242 | 243 | config_instance = Config() 244 | main(config_instance.__dict__) 245 | -------------------------------------------------------------------------------- /finetune/train_tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import time 5 | from time import gmtime, strftime 6 | import argparse 7 | import datetime 8 | import torch.distributed as dist 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader 12 | from torch.utils.data.distributed import DistributedSampler 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | 15 | import comet_ml 16 | 17 | # Ensure project root is in path 18 | sys.path.append("../") 19 | from config import Config 20 | from dataset import QlibDataset 21 | from model.kronos import KronosTokenizer 22 | # Import shared utilities 23 | from utils.training_utils import ( 24 | setup_ddp, 25 | cleanup_ddp, 26 | set_seed, 27 | get_model_size, 28 | format_time, 29 | ) 30 | 31 | 32 | def create_dataloaders(config: dict, rank: int, world_size: int): 33 | """ 34 | Creates and returns distributed dataloaders for training and validation. 35 | 36 | Args: 37 | config (dict): A dictionary of configuration parameters. 38 | rank (int): The global rank of the current process. 39 | world_size (int): The total number of processes. 40 | 41 | Returns: 42 | tuple: A tuple containing (train_loader, val_loader, train_dataset, valid_dataset). 43 | """ 44 | print(f"[Rank {rank}] Creating distributed dataloaders...") 45 | train_dataset = QlibDataset('train') 46 | valid_dataset = QlibDataset('val') 47 | print(f"[Rank {rank}] Train dataset size: {len(train_dataset)}, Validation dataset size: {len(valid_dataset)}") 48 | 49 | train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True) 50 | val_sampler = DistributedSampler(valid_dataset, num_replicas=world_size, rank=rank, shuffle=False) 51 | 52 | train_loader = DataLoader( 53 | train_dataset, 54 | batch_size=config['batch_size'], 55 | sampler=train_sampler, 56 | shuffle=False, # Shuffle is handled by the sampler 57 | num_workers=config.get('num_workers', 2), 58 | pin_memory=True, 59 | drop_last=True 60 | ) 61 | val_loader = DataLoader( 62 | valid_dataset, 63 | batch_size=config['batch_size'], 64 | sampler=val_sampler, 65 | shuffle=False, 66 | num_workers=config.get('num_workers', 2), 67 | pin_memory=True, 68 | drop_last=False 69 | ) 70 | print(f"[Rank {rank}] Dataloaders created. Train steps/epoch: {len(train_loader)}, Val steps: {len(val_loader)}") 71 | return train_loader, val_loader, train_dataset, valid_dataset 72 | 73 | 74 | def train_model(model, device, config, save_dir, logger, rank, world_size): 75 | """ 76 | The main training and validation loop for the tokenizer. 77 | 78 | Args: 79 | model (DDP): The DDP-wrapped model to train. 80 | device (torch.device): The device for the current process. 81 | config (dict): Configuration dictionary. 82 | save_dir (str): Directory to save checkpoints. 83 | logger (comet_ml.Experiment): Comet logger instance. 84 | rank (int): Global rank of the process. 
85 | world_size (int): Total number of processes. 86 | 87 | Returns: 88 | tuple: A tuple containing the trained model and a dictionary of results. 89 | """ 90 | start_time = time.time() 91 | if rank == 0: 92 | effective_bs = config['batch_size'] * world_size * config['accumulation_steps'] 93 | print(f"[Rank {rank}] BATCHSIZE (per GPU): {config['batch_size']}") 94 | print(f"[Rank {rank}] Effective total batch size: {effective_bs}") 95 | 96 | train_loader, val_loader, train_dataset, valid_dataset = create_dataloaders(config, rank, world_size) 97 | 98 | optimizer = torch.optim.AdamW( 99 | model.parameters(), 100 | lr=config['tokenizer_learning_rate'], 101 | weight_decay=config['adam_weight_decay'] 102 | ) 103 | 104 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 105 | optimizer=optimizer, 106 | max_lr=config['tokenizer_learning_rate'], 107 | steps_per_epoch=len(train_loader), 108 | epochs=config['epochs'], 109 | pct_start=0.03, 110 | div_factor=10 111 | ) 112 | 113 | best_val_loss = float('inf') 114 | dt_result = {} 115 | batch_idx_global_train = 0 116 | 117 | for epoch_idx in range(config['epochs']): 118 | epoch_start_time = time.time() 119 | model.train() 120 | train_loader.sampler.set_epoch(epoch_idx) 121 | 122 | # Set dataset seeds for reproducible sampling 123 | train_dataset.set_epoch_seed(epoch_idx * 10000 + rank) 124 | valid_dataset.set_epoch_seed(0) # Keep validation sampling consistent 125 | 126 | for i, (ori_batch_x, _) in enumerate(train_loader): 127 | ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True) 128 | 129 | # --- Gradient Accumulation Loop --- 130 | current_batch_total_loss = 0.0 131 | for j in range(config['accumulation_steps']): 132 | start_idx = j * (ori_batch_x.shape[0] // config['accumulation_steps']) 133 | end_idx = (j + 1) * (ori_batch_x.shape[0] // config['accumulation_steps']) 134 | batch_x = ori_batch_x[start_idx:end_idx] 135 | 136 | # Forward pass 137 | zs, bsq_loss, _, _ = model(batch_x) 138 | z_pre, z = zs 139 | 140 | # Loss calculation 141 | recon_loss_pre = F.mse_loss(z_pre, batch_x) 142 | recon_loss_all = F.mse_loss(z, batch_x) 143 | recon_loss = recon_loss_pre + recon_loss_all 144 | loss = (recon_loss + bsq_loss) / 2 # Assuming w_1=w_2=1 145 | 146 | loss_scaled = loss / config['accumulation_steps'] 147 | current_batch_total_loss += loss.item() 148 | loss_scaled.backward() 149 | 150 | # --- Optimizer Step after Accumulation --- 151 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0) 152 | optimizer.step() 153 | scheduler.step() 154 | optimizer.zero_grad() 155 | 156 | # --- Logging (Master Process Only) --- 157 | if rank == 0 and (batch_idx_global_train + 1) % config['log_interval'] == 0: 158 | avg_loss = current_batch_total_loss / config['accumulation_steps'] 159 | print( 160 | f"[Rank {rank}, Epoch {epoch_idx + 1}/{config['epochs']}, Step {i + 1}/{len(train_loader)}] " 161 | f"LR {optimizer.param_groups[0]['lr']:.6f}, Loss: {avg_loss:.4f}" 162 | ) 163 | if rank == 0 and logger: 164 | avg_loss = current_batch_total_loss / config['accumulation_steps'] 165 | logger.log_metric('train_tokenizer_loss_batch', avg_loss, step=batch_idx_global_train) 166 | logger.log_metric(f'train_vqvae_vq_loss_each_batch', bsq_loss.item(), step=batch_idx_global_train) 167 | logger.log_metric(f'train_recon_loss_pre_each_batch', recon_loss_pre.item(), step=batch_idx_global_train) 168 | logger.log_metric(f'train_recon_loss_each_batch', recon_loss_all.item(), step=batch_idx_global_train) 169 | logger.log_metric('tokenizer_learning_rate', 
optimizer.param_groups[0]["lr"], step=batch_idx_global_train) 170 | 171 | batch_idx_global_train += 1 172 | 173 | # --- Validation Loop --- 174 | model.eval() 175 | tot_val_loss_sum_rank = 0.0 176 | val_sample_count_rank = 0 177 | with torch.no_grad(): 178 | for ori_batch_x, _ in val_loader: 179 | ori_batch_x = ori_batch_x.squeeze(0).to(device, non_blocking=True) 180 | zs, _, _, _ = model(ori_batch_x) 181 | _, z = zs 182 | val_loss_item = F.mse_loss(z, ori_batch_x) 183 | 184 | tot_val_loss_sum_rank += val_loss_item.item() * ori_batch_x.size(0) 185 | val_sample_count_rank += ori_batch_x.size(0) 186 | 187 | # Reduce validation losses from all processes 188 | val_loss_sum_tensor = torch.tensor(tot_val_loss_sum_rank, device=device) 189 | val_count_tensor = torch.tensor(val_sample_count_rank, device=device) 190 | dist.all_reduce(val_loss_sum_tensor, op=dist.ReduceOp.SUM) 191 | dist.all_reduce(val_count_tensor, op=dist.ReduceOp.SUM) 192 | 193 | avg_val_loss = val_loss_sum_tensor.item() / val_count_tensor.item() if val_count_tensor.item() > 0 else 0 194 | 195 | # --- End of Epoch Summary & Checkpointing (Master Process Only) --- 196 | if rank == 0: 197 | print(f"\n--- Epoch {epoch_idx + 1}/{config['epochs']} Summary ---") 198 | print(f"Validation Loss: {avg_val_loss:.4f}") 199 | print(f"Time This Epoch: {format_time(time.time() - epoch_start_time)}") 200 | print(f"Total Time Elapsed: {format_time(time.time() - start_time)}\n") 201 | if logger: 202 | logger.log_metric('val_tokenizer_loss_epoch', avg_val_loss, epoch=epoch_idx) 203 | 204 | if avg_val_loss < best_val_loss: 205 | best_val_loss = avg_val_loss 206 | save_path = f"{save_dir}/checkpoints/best_model" 207 | model.module.save_pretrained(save_path) 208 | print(f"Best model saved to {save_path} (Val Loss: {best_val_loss:.4f})") 209 | if logger: 210 | logger.log_model("best_model", save_path) 211 | 212 | dist.barrier() # Ensure all processes finish the epoch before starting the next one. 213 | 214 | dt_result['best_val_loss'] = best_val_loss 215 | return model, dt_result 216 | 217 | 218 | def main(config: dict): 219 | """ 220 | Main function to orchestrate the DDP training process. 
221 | """ 222 | rank, world_size, local_rank = setup_ddp() 223 | device = torch.device(f"cuda:{local_rank}") 224 | set_seed(config['seed'], rank) 225 | 226 | save_dir = os.path.join(config['save_path'], config['tokenizer_save_folder_name']) 227 | 228 | # Logger and summary setup (master process only) 229 | comet_logger, master_summary = None, {} 230 | if rank == 0: 231 | os.makedirs(os.path.join(save_dir, 'checkpoints'), exist_ok=True) 232 | master_summary = { 233 | 'start_time': strftime("%Y-%m-%dT%H-%M-%S", gmtime()), 234 | 'save_directory': save_dir, 235 | 'world_size': world_size, 236 | } 237 | if config['use_comet']: 238 | comet_logger = comet_ml.Experiment( 239 | api_key=config['comet_config']['api_key'], 240 | project_name=config['comet_config']['project_name'], 241 | workspace=config['comet_config']['workspace'], 242 | ) 243 | comet_logger.add_tag(config['comet_tag']) 244 | comet_logger.set_name(config['comet_name']) 245 | comet_logger.log_parameters(config) 246 | print("Comet Logger Initialized.") 247 | 248 | dist.barrier() # Ensure save directory is created before proceeding 249 | 250 | # Model Initialization 251 | model = KronosTokenizer.from_pretrained(config['pretrained_tokenizer_path']) 252 | model.to(device) 253 | model = DDP(model, device_ids=[local_rank], find_unused_parameters=False) 254 | 255 | if rank == 0: 256 | print(f"Model Size: {get_model_size(model.module)}") 257 | 258 | # Start Training 259 | _, dt_result = train_model( 260 | model, device, config, save_dir, comet_logger, rank, world_size 261 | ) 262 | 263 | # Finalize and save summary (master process only) 264 | if rank == 0: 265 | master_summary['final_result'] = dt_result 266 | with open(os.path.join(save_dir, 'summary.json'), 'w') as f: 267 | json.dump(master_summary, f, indent=4) 268 | print('Training finished. Summary file saved.') 269 | if comet_logger: 270 | comet_logger.end() 271 | 272 | cleanup_ddp() 273 | 274 | 275 | if __name__ == '__main__': 276 | # Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_tokenizer.py 277 | if "WORLD_SIZE" not in os.environ: 278 | raise RuntimeError("This script must be launched with `torchrun`.") 279 | 280 | config_instance = Config() 281 | main(config_instance.__dict__) 282 | -------------------------------------------------------------------------------- /finetune/qlib_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import pickle 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import torch 10 | from torch.utils.data import Dataset, DataLoader 11 | from tqdm import trange, tqdm 12 | from matplotlib import pyplot as plt 13 | 14 | import qlib 15 | from qlib.config import REG_CN 16 | from qlib.backtest import backtest, executor, CommonInfrastructure 17 | from qlib.contrib.evaluate import risk_analysis 18 | from qlib.contrib.strategy import TopkDropoutStrategy 19 | from qlib.utils import flatten_dict 20 | from qlib.utils.time import Freq 21 | 22 | # Ensure project root is in the Python path 23 | sys.path.append("../") 24 | from config import Config 25 | from model.kronos import Kronos, KronosTokenizer, auto_regressive_inference 26 | 27 | 28 | # ================================================================================= 29 | # 1. 
Data Loading and Processing for Inference 30 | # ================================================================================= 31 | 32 | class QlibTestDataset(Dataset): 33 | """ 34 | PyTorch Dataset for handling Qlib test data, specifically for inference. 35 | 36 | This dataset iterates through all possible sliding windows sequentially. It also 37 | yields metadata like symbol and timestamp, which are crucial for mapping 38 | predictions back to the original time series. 39 | """ 40 | 41 | def __init__(self, data: dict, config: Config): 42 | self.data = data 43 | self.config = config 44 | self.window_size = config.lookback_window + config.predict_window 45 | self.symbols = list(self.data.keys()) 46 | self.feature_list = config.feature_list 47 | self.time_feature_list = config.time_feature_list 48 | self.indices = [] 49 | 50 | print("Preprocessing and building indices for test dataset...") 51 | for symbol in self.symbols: 52 | df = self.data[symbol].reset_index() 53 | # Generate time features on-the-fly 54 | df['minute'] = df['datetime'].dt.minute 55 | df['hour'] = df['datetime'].dt.hour 56 | df['weekday'] = df['datetime'].dt.weekday 57 | df['day'] = df['datetime'].dt.day 58 | df['month'] = df['datetime'].dt.month 59 | self.data[symbol] = df # Store preprocessed dataframe 60 | 61 | num_samples = len(df) - self.window_size + 1 62 | if num_samples > 0: 63 | for i in range(num_samples): 64 | timestamp = df.iloc[i + self.config.lookback_window - 1]['datetime'] 65 | self.indices.append((symbol, i, timestamp)) 66 | 67 | def __len__(self) -> int: 68 | return len(self.indices) 69 | 70 | def __getitem__(self, idx: int): 71 | symbol, start_idx, timestamp = self.indices[idx] 72 | df = self.data[symbol] 73 | 74 | context_end = start_idx + self.config.lookback_window 75 | predict_end = context_end + self.config.predict_window 76 | 77 | context_df = df.iloc[start_idx:context_end] 78 | predict_df = df.iloc[context_end:predict_end] 79 | 80 | x = context_df[self.feature_list].values.astype(np.float32) 81 | x_stamp = context_df[self.time_feature_list].values.astype(np.float32) 82 | y_stamp = predict_df[self.time_feature_list].values.astype(np.float32) 83 | 84 | # Instance-level normalization, consistent with training 85 | x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0) 86 | x = (x - x_mean) / (x_std + 1e-5) 87 | x = np.clip(x, -self.config.clip, self.config.clip) 88 | 89 | return torch.from_numpy(x), torch.from_numpy(x_stamp), torch.from_numpy(y_stamp), symbol, timestamp 90 | 91 | 92 | # ================================================================================= 93 | # 2. Backtesting Logic 94 | # ================================================================================= 95 | 96 | class QlibBacktest: 97 | """ 98 | A wrapper class for conducting backtesting experiments using Qlib. 99 | """ 100 | 101 | def __init__(self, config: Config): 102 | self.config = config 103 | self.initialize_qlib() 104 | 105 | def initialize_qlib(self): 106 | """Initializes the Qlib environment.""" 107 | print("Initializing Qlib for backtesting...") 108 | qlib.init(provider_uri=self.config.qlib_data_path, region=REG_CN) 109 | 110 | def run_single_backtest(self, signal_series: pd.Series) -> pd.DataFrame: 111 | """ 112 | Runs a single backtest for a given prediction signal. 113 | 114 | Args: 115 | signal_series (pd.Series): A pandas Series with a MultiIndex 116 | (instrument, datetime) and prediction scores. 117 | Returns: 118 | pd.DataFrame: A DataFrame containing the performance report. 
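
        Illustrative construction of such a signal (tickers and scores are
        hypothetical; only the MultiIndex layout matters):

            >>> import pandas as pd
            >>> idx = pd.MultiIndex.from_tuples(
            ...     [("SH600000", pd.Timestamp("2023-01-03")),
            ...      ("SH600519", pd.Timestamp("2023-01-03"))],
            ...     names=["instrument", "datetime"])
            >>> signal_series = pd.Series([0.8, -0.2], index=idx)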
119 | """ 120 | strategy = TopkDropoutStrategy( 121 | topk=self.config.backtest_n_symbol_hold, 122 | n_drop=self.config.backtest_n_symbol_drop, 123 | hold_thresh=self.config.backtest_hold_thresh, 124 | signal=signal_series, 125 | ) 126 | executor_config = { 127 | "time_per_step": "day", 128 | "generate_portfolio_metrics": True, 129 | "delay_execution": True, 130 | } 131 | backtest_config = { 132 | "start_time": self.config.backtest_time_range[0], 133 | "end_time": self.config.backtest_time_range[1], 134 | "account": 100_000_000, 135 | "benchmark": self.config.backtest_benchmark, 136 | "exchange_kwargs": { 137 | "freq": "day", "limit_threshold": 0.095, "deal_price": "open", 138 | "open_cost": 0.001, "close_cost": 0.0015, "min_cost": 5, 139 | }, 140 | "executor": executor.SimulatorExecutor(**executor_config), 141 | } 142 | 143 | portfolio_metric_dict, _ = backtest(strategy=strategy, **backtest_config) 144 | analysis_freq = "{0}{1}".format(*Freq.parse("day")) 145 | report, _ = portfolio_metric_dict.get(analysis_freq) 146 | 147 | # --- Analysis and Reporting --- 148 | analysis = { 149 | "excess_return_without_cost": risk_analysis(report["return"] - report["bench"], freq=analysis_freq), 150 | "excess_return_with_cost": risk_analysis(report["return"] - report["bench"] - report["cost"], freq=analysis_freq), 151 | } 152 | print("\n--- Backtest Analysis ---") 153 | print("Benchmark Return:", risk_analysis(report["bench"], freq=analysis_freq), sep='\n') 154 | print("\nExcess Return (w/o cost):", analysis["excess_return_without_cost"], sep='\n') 155 | print("\nExcess Return (w/ cost):", analysis["excess_return_with_cost"], sep='\n') 156 | 157 | report_df = pd.DataFrame({ 158 | "cum_bench": report["bench"].cumsum(), 159 | "cum_return_w_cost": (report["return"] - report["cost"]).cumsum(), 160 | "cum_ex_return_w_cost": (report["return"] - report["bench"] - report["cost"]).cumsum(), 161 | }) 162 | return report_df 163 | 164 | def run_and_plot_results(self, signals: dict[str, pd.DataFrame]): 165 | """ 166 | Runs backtests for multiple signals and plots the cumulative return curves. 167 | 168 | Args: 169 | signals (dict[str, pd.DataFrame]): A dictionary where keys are signal names 170 | and values are prediction DataFrames. 
171 | """ 172 | return_df, ex_return_df, bench_df = pd.DataFrame(), pd.DataFrame(), pd.DataFrame() 173 | 174 | for signal_name, pred_df in signals.items(): 175 | print(f"\nBacktesting signal: {signal_name}...") 176 | pred_series = pred_df.stack() 177 | pred_series.index.names = ['datetime', 'instrument'] 178 | pred_series = pred_series.swaplevel().sort_index() 179 | report_df = self.run_single_backtest(pred_series) 180 | 181 | return_df[signal_name] = report_df['cum_return_w_cost'] 182 | ex_return_df[signal_name] = report_df['cum_ex_return_w_cost'] 183 | if 'return' not in bench_df: 184 | bench_df['return'] = report_df['cum_bench'] 185 | 186 | # Plotting results 187 | fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True) 188 | return_df.plot(ax=axes[0], title='Cumulative Return with Cost', grid=True) 189 | axes[0].plot(bench_df['return'], label=self.config.instrument.upper(), color='black', linestyle='--') 190 | axes[0].legend() 191 | axes[0].set_ylabel("Cumulative Return") 192 | 193 | ex_return_df.plot(ax=axes[1], title='Cumulative Excess Return with Cost', grid=True) 194 | axes[1].legend() 195 | axes[1].set_xlabel("Date") 196 | axes[1].set_ylabel("Cumulative Excess Return") 197 | 198 | plt.tight_layout() 199 | plt.savefig("../figures/backtest_result_example.png", dpi=200) 200 | plt.show() 201 | 202 | 203 | # ================================================================================= 204 | # 3. Inference Logic 205 | # ================================================================================= 206 | 207 | def load_models(config: dict) -> tuple[KronosTokenizer, Kronos]: 208 | """Loads the fine-tuned tokenizer and predictor model.""" 209 | device = torch.device(config['device']) 210 | print(f"Loading models onto device: {device}...") 211 | tokenizer = KronosTokenizer.from_pretrained(config['tokenizer_path']).to(device).eval() 212 | model = Kronos.from_pretrained(config['model_path']).to(device).eval() 213 | return tokenizer, model 214 | 215 | 216 | def collate_fn_for_inference(batch): 217 | """ 218 | Custom collate function to handle batches containing Tensors, strings, and Timestamps. 219 | 220 | Args: 221 | batch (list): A list of samples, where each sample is the tuple returned by 222 | QlibTestDataset.__getitem__. 223 | 224 | Returns: 225 | A single tuple containing the batched data. 226 | """ 227 | # Unzip the list of samples into separate lists for each data type 228 | x, x_stamp, y_stamp, symbols, timestamps = zip(*batch) 229 | 230 | # Stack the tensors to create a batch 231 | x_batch = torch.stack(x, dim=0) 232 | x_stamp_batch = torch.stack(x_stamp, dim=0) 233 | y_stamp_batch = torch.stack(y_stamp, dim=0) 234 | 235 | # Return the strings and timestamps as lists 236 | return x_batch, x_stamp_batch, y_stamp_batch, list(symbols), list(timestamps) 237 | 238 | 239 | def generate_predictions(config: dict, test_data: dict) -> dict[str, pd.DataFrame]: 240 | """ 241 | Runs inference on the test dataset to generate prediction signals. 242 | 243 | Args: 244 | config (dict): A dictionary containing inference parameters. 245 | test_data (dict): The raw test data loaded from a pickle file. 246 | 247 | Returns: 248 | A dictionary where keys are signal types (e.g., 'mean', 'last') and 249 | values are DataFrames of predictions (datetime index, symbol columns). 
250 | """ 251 | tokenizer, model = load_models(config) 252 | device = torch.device(config['device']) 253 | 254 | # Use the Dataset and DataLoader for efficient batching and processing 255 | dataset = QlibTestDataset(data=test_data, config=Config()) 256 | loader = DataLoader( 257 | dataset, 258 | batch_size=config['batch_size'] // config['sample_count'], 259 | shuffle=False, 260 | num_workers=os.cpu_count() // 2, 261 | collate_fn=collate_fn_for_inference 262 | ) 263 | 264 | results = defaultdict(list) 265 | with torch.no_grad(): 266 | for x, x_stamp, y_stamp, symbols, timestamps in tqdm(loader, desc="Inference"): 267 | preds = auto_regressive_inference( 268 | tokenizer, model, x.to(device), x_stamp.to(device), y_stamp.to(device), 269 | max_context=config['max_context'], pred_len=config['pred_len'], clip=config['clip'], 270 | T=config['T'], top_k=config['top_k'], top_p=config['top_p'], sample_count=config['sample_count'] 271 | ) 272 | 273 | # The 'close' price is at index 3 in `feature_list` 274 | last_day_close = x[:, -1, 3].numpy() 275 | signals = { 276 | 'last': preds[:, -1, 3] - last_day_close, 277 | 'mean': np.mean(preds[:, :, 3], axis=1) - last_day_close, 278 | 'max': np.max(preds[:, :, 3], axis=1) - last_day_close, 279 | 'min': np.min(preds[:, :, 3], axis=1) - last_day_close, 280 | } 281 | 282 | for i in range(len(symbols)): 283 | for sig_type, sig_values in signals.items(): 284 | results[sig_type].append((timestamps[i], symbols[i], sig_values[i])) 285 | 286 | print("Post-processing predictions into DataFrames...") 287 | prediction_dfs = {} 288 | for sig_type, records in results.items(): 289 | df = pd.DataFrame(records, columns=['datetime', 'instrument', 'score']) 290 | pivot_df = df.pivot_table(index='datetime', columns='instrument', values='score') 291 | prediction_dfs[sig_type] = pivot_df.sort_index() 292 | 293 | return prediction_dfs 294 | 295 | 296 | # ================================================================================= 297 | # 4. Main Execution 298 | # ================================================================================= 299 | 300 | def main(): 301 | """Main function to set up config, run inference, and execute backtesting.""" 302 | parser = argparse.ArgumentParser(description="Run Kronos Inference and Backtesting") 303 | parser.add_argument("--device", type=str, default="cuda:1", help="Device for inference (e.g., 'cuda:0', 'cpu')") 304 | args = parser.parse_args() 305 | 306 | # --- 1. Configuration Setup --- 307 | base_config = Config() 308 | 309 | # Create a dedicated dictionary for this run's configuration 310 | run_config = { 311 | 'device': args.device, 312 | 'data_path': base_config.dataset_path, 313 | 'result_save_path': base_config.backtest_result_path, 314 | 'result_name': base_config.backtest_save_folder_name, 315 | 'tokenizer_path': base_config.finetuned_tokenizer_path, 316 | 'model_path': base_config.finetuned_predictor_path, 317 | 'max_context': base_config.max_context, 318 | 'pred_len': base_config.predict_window, 319 | 'clip': base_config.clip, 320 | 'T': base_config.inference_T, 321 | 'top_k': base_config.inference_top_k, 322 | 'top_p': base_config.inference_top_p, 323 | 'sample_count': base_config.inference_sample_count, 324 | 'batch_size': base_config.backtest_batch_size, 325 | } 326 | 327 | print("--- Running with Configuration ---") 328 | for key, val in run_config.items(): 329 | print(f"{key:>20}: {val}") 330 | print("-" * 35) 331 | 332 | # --- 2. 
Load Data ---
333 |     test_data_path = os.path.join(run_config['data_path'], "test_data.pkl")
334 |     print(f"Loading test data from {test_data_path}...")
335 |     with open(test_data_path, 'rb') as f:
336 |         test_data = pickle.load(f)
337 |     print(f"Loaded test data for {len(test_data)} symbols.")
338 |     # --- 3. Generate Predictions ---
339 |     model_preds = generate_predictions(run_config, test_data)
340 | 
341 |     # --- 4. Save Predictions ---
342 |     save_dir = os.path.join(run_config['result_save_path'], run_config['result_name'])
343 |     os.makedirs(save_dir, exist_ok=True)
344 |     predictions_file = os.path.join(save_dir, "predictions.pkl")
345 |     print(f"Saving prediction signals to {predictions_file}...")
346 |     with open(predictions_file, 'wb') as f:
347 |         pickle.dump(model_preds, f)
348 | 
349 |     # --- 5. Run Backtesting ---
350 |     with open(predictions_file, 'rb') as f:
351 |         model_preds = pickle.load(f)
352 | 
353 |     backtester = QlibBacktest(base_config)
354 |     backtester.run_and_plot_results(model_preds)
355 | 
356 | 
357 | if __name__ == '__main__':
358 |     main()
359 | 
--------------------------------------------------------------------------------
/facecat/stock.py:
--------------------------------------------------------------------------------
1 | from facecat import *
2 | from datetime import datetime
3 | 
4 | class ClientTickDataCache:
5 |     def __init__(self):
6 |         self.code = ""  # security code
7 |         self.lastAmount = 0  # last cumulative turnover
8 |         self.lastDate = 0  # last date
9 |         self.lastVolume = 0  # last cumulative volume
10 | 
11 | class ADJUSTMENTFACTOR:
12 |     def __init__(self):
13 |         self.dwDate = 0  # date of the adjustment
14 |         self.f1 = 0  # cash dividend per 10 shares
15 |         self.f2 = 0  # rights issue price
16 |         self.f3 = 0  # bonus shares per 10 shares
17 |         self.f4 = 0  # rights shares per 10 shares
18 | 
19 | def getDateNum(year, month, day, hour, minute, second, millisecond):
20 |     """ Get the timestamp of a date
21 |     year  the year
22 |     month  the month
23 |     day  the day
24 |     hour  the hour
25 |     minute  the minute
26 |     second  the second
27 |     millisecond  the millisecond
28 |     @returns the timestamp of the date"""
29 |     date = datetime(year, month, day, hour, minute, second, millisecond)
30 |     return int(date.timestamp())
31 | 
32 | def numToDate(num):
33 |     """ Convert a timestamp to a date
34 |     num  the timestamp
35 |     @returns the date object"""
36 |     date = datetime.fromtimestamp(num)
37 |     return date
38 | 
39 | def getSeason(month):
40 |     """ Get the quarter
41 |     month  the month
42 |     @returns the quarter"""
43 |     if 1 <= month <= 3:
44 |         return 1
45 |     elif 4 <= month <= 6:
46 |         return 2
47 |     elif 7 <= month <= 9:
48 |         return 3
49 |     else:
50 |         return 4
51 | 
52 | def copySecurityData(data):
53 |     """ Copy security data
54 |     data  the original data
55 |     @returns the new data"""
56 |     newData = SecurityData()
57 |     newData.date = data.date
58 |     newData.high = data.high
59 |     newData.low = data.low
60 |     newData.open = data.open
61 |     newData.close = data.close
62 |     newData.amount = data.amount
63 |     newData.volume = data.volume
64 |     return newData
65 | 
66 | def multiMinuteSecurityDatas(newDatas, minuteDatas, cycle):
67 |     """ Aggregate 1-minute bars into multi-minute bars
68 |     newDatas  output array of aggregated bars
69 |     minuteDatas  array of 1-minute bars
70 |     cycle  the cycle in minutes"""
71 |     lastMinutes = 0
72 |     for minuteData in minuteDatas:
73 |         minutes = minuteData.date // 60
74 |         if lastMinutes == 0:
75 |             lastMinutes = minutes
76 |         # Update the current bar
77 |         if newDatas and minutes - lastMinutes < cycle:
78 |             lastData = newDatas[len(newDatas) - 1]
79 |             lastData.close = minuteData.close
80 |             if minuteData.high > lastData.high:
81 |                 lastData.high = minuteData.high
82 |             if minuteData.low < lastData.low:
83 |                 lastData.low = minuteData.low
84 |             lastData.amount += minuteData.amount
85 |             lastData.volume += minuteData.volume
86 |         else:
87 |             newData = copySecurityData(minuteData)
88 |             newDatas.append(newData)
89 |             lastMinutes = minutes
90 | 
91 | def getHistoryWeekDatas(weekDatas, dayDatas):
""" 获取历史周数据 93 | weekDatas 周数据数组 94 | dayDatas 日数据数组 95 | @returns 返回操作结果""" 96 | dayDatasSize = len(dayDatas) 97 | if dayDatasSize > 0: 98 | firstDate = getDateNum(1970, 1, 5, 0, 0, 0, 0) 99 | weekData = copySecurityData(dayDatas[0]) 100 | lWeeks = (weekData.date - firstDate) // 86400 // 7 101 | for i in range(dayDatasSize): 102 | dayData = copySecurityData(dayDatas[i]) 103 | weeks = (dayData.date - firstDate) // 86400 // 7 104 | isNextWeek = weeks > lWeeks 105 | if isNextWeek: 106 | weekDatas.append(weekData) 107 | weekData = copySecurityData(dayData) 108 | if i == dayDatasSize - 1: 109 | weekDatas.append(weekData) 110 | else: 111 | if i > 0: 112 | weekData.close = dayData.close 113 | weekData.amount += dayData.amount 114 | weekData.volume += dayData.volume 115 | if weekData.high < dayData.high: 116 | weekData.high = dayData.high 117 | if weekData.low > dayData.low: 118 | weekData.low = dayData.low 119 | if i == dayDatasSize - 1: 120 | weekDatas.append(weekData) 121 | lWeeks = weeks 122 | return 1 123 | 124 | def getHistoryMonthDatas(monthDatas, dayDatas): 125 | """ 获取历史月数据 126 | monthDatas 月数据数组 127 | dayDatas 日数据数组 128 | 返回操作结果""" 129 | dayDatasSize = len(dayDatas) 130 | if dayDatasSize > 0: 131 | monthData = copySecurityData(dayDatas[0]) 132 | ldate = numToDate(monthData.date) 133 | lYear = ldate.year 134 | lMonth = ldate.month 135 | lDay = ldate.day 136 | for i in range(dayDatasSize): 137 | dayData = copySecurityData(dayDatas[i]) 138 | date = numToDate(dayData.date) 139 | year = date.year 140 | month = date.month 141 | day = date.day 142 | isNextMonth = year * 12 + month > lYear * 12 + lMonth 143 | if isNextMonth: 144 | monthDatas.append(monthData) 145 | monthData = copySecurityData(dayData) 146 | if i == dayDatasSize - 1: 147 | monthDatas.append(monthData) 148 | else: 149 | if i > 0: 150 | monthData.close = dayData.close 151 | monthData.amount += dayData.amount 152 | monthData.volume += dayData.volume 153 | if monthData.high < dayData.high: 154 | monthData.high = dayData.high 155 | if monthData.low > dayData.low: 156 | monthData.low = dayData.low 157 | if i == dayDatasSize - 1: 158 | monthDatas.append(monthData) 159 | lYear = year 160 | lMonth = month 161 | lDay = day 162 | return 1 163 | 164 | def getHistorySeasonDatas(seasonDatas, dayDatas): 165 | """ 获取历史季节数据 166 | seasonDatas 季节数据数组 167 | dayDatas 日数据数组 168 | @returns 返回操作结果""" 169 | dayDatasSize = len(dayDatas) 170 | if dayDatasSize > 0: 171 | seasonData = copySecurityData(dayDatas[0]) 172 | ldate = numToDate(seasonData.date) 173 | lYear = ldate.year 174 | lMonth = ldate.month 175 | lDay = ldate.day 176 | for i in range(dayDatasSize): 177 | dayData = copySecurityData(dayDatas[i]) 178 | date = numToDate(dayData.date) 179 | year = date.year 180 | month = date.month 181 | day = date.day 182 | isNextSeason = year * 4 + getSeason(month) > lYear * 4 + getSeason(lMonth) 183 | if isNextSeason: 184 | seasonDatas.append(seasonData) 185 | seasonData = copySecurityData(dayData) 186 | if i == dayDatasSize - 1: 187 | seasonDatas.append(seasonData) 188 | else: 189 | if i > 0: 190 | seasonData.close = dayData.close 191 | seasonData.amount += dayData.amount 192 | seasonData.volume += dayData.volume 193 | if seasonData.high < dayData.high: 194 | seasonData.high = dayData.high 195 | if seasonData.low > dayData.low: 196 | seasonData.low = dayData.low 197 | if i == dayDatasSize - 1: 198 | seasonDatas.append(seasonData) 199 | lYear = year 200 | lMonth = month 201 | lDay = day 202 | return 1 203 | 204 | def getHistoryHalfYearDatas(halfYearDatas, 
205 |     """ Build historical half-year data
206 |     halfYearDatas  output array of half-year bars
207 |     dayDatas  array of daily bars
208 |     @returns the operation result"""
209 |     dayDatasSize = len(dayDatas)
210 |     if dayDatasSize > 0:
211 |         yearData = copySecurityData(dayDatas[0])
212 |         ldate = numToDate(yearData.date)
213 |         lyear = ldate.year
214 |         lmonth = ldate.month
215 |         for i in range(dayDatasSize):
216 |             dayData = copySecurityData(dayDatas[i])
217 |             date = numToDate(dayData.date)
218 |             year = date.year
219 |             month = date.month
220 |             isNextHalfYear = year * 2 + month // 6 > lyear * 2 + lmonth // 6
221 |             if isNextHalfYear:
222 |                 halfYearDatas.append(yearData)
223 |                 yearData = copySecurityData(dayData)
224 |                 if i == dayDatasSize - 1:
225 |                     halfYearDatas.append(yearData)
226 |             else:
227 |                 if i > 0:
228 |                     yearData.close = dayData.close
229 |                     yearData.amount += dayData.amount
230 |                     yearData.volume += dayData.volume
231 |                     if yearData.high < dayData.high:
232 |                         yearData.high = dayData.high
233 |                     if yearData.low > dayData.low:
234 |                         yearData.low = dayData.low
235 |                 if i == dayDatasSize - 1:
236 |                     halfYearDatas.append(yearData)
237 |             lyear = year
238 |             lmonth = month
239 |     return 1
240 | 
241 | def getHistoryYearDatas(yearDatas, dayDatas):
242 |     """ Build historical yearly data
243 |     yearDatas  output array of yearly bars
244 |     dayDatas  array of daily bars
245 |     @returns the operation result"""
246 |     dayDatasSize = len(dayDatas)
247 |     if dayDatasSize > 0:
248 |         yearData = copySecurityData(dayDatas[0])
249 |         ldate = numToDate(yearData.date)
250 |         lyear = ldate.year
251 |         lmonth = ldate.month
252 |         for i in range(dayDatasSize):
253 |             dayData = copySecurityData(dayDatas[i])
254 |             date = numToDate(dayData.date)
255 |             year = date.year
256 |             month = date.month
257 |             isNextYear = year > lyear
258 |             if isNextYear:
259 |                 yearDatas.append(yearData)
260 |                 yearData = copySecurityData(dayData)
261 |                 if i == dayDatasSize - 1:
262 |                     yearDatas.append(yearData)
263 |             else:
264 |                 if i > 0:
265 |                     yearData.close = dayData.close
266 |                     yearData.amount += dayData.amount
267 |                     yearData.volume += dayData.volume
268 |                     if yearData.high < dayData.high:
269 |                         yearData.high = dayData.high
270 |                     if yearData.low > dayData.low:
271 |                         yearData.low = dayData.low
272 |                 if i == dayDatasSize - 1:
273 |                     yearDatas.append(yearData)
274 |             lyear = year
275 |             lmonth = month
276 |     return 1
277 | 
278 | def mergeLatestData(code, oldDatas, latestData, tickDataCache, dCycle):
279 |     """ Merge the latest quote into the bar series
280 |     code  the security code
281 |     oldDatas  array of existing bars
282 |     latestData  the latest data object
283 |     tickDataCache  the tick data cache object
284 |     dCycle  the cycle"""
285 |     cycle = dCycle
286 |     if cycle == 0:
287 |         cycle = 1
288 |     if latestData.open <= 0 or latestData.volume <= 0 or latestData.amount <= 0:
289 |         return
290 |     newDate = numToDate(latestData.date)
291 |     hourMinute = newDate.hour * 60 + newDate.minute
292 |     if hourMinute < 570:
293 |         newDate = newDate.replace(hour=9, minute=30, second=0, microsecond=0)
294 |         latestData.date = int(newDate.timestamp())
295 |     elif hourMinute < 571:
296 |         newDate = newDate.replace(hour=9, minute=31, second=0, microsecond=0)
297 |         latestData.date = int(newDate.timestamp())
298 |     elif hourMinute > 900:
299 |         newDate = newDate.replace(hour=15, minute=0, second=0, microsecond=0)
300 |         latestData.date = int(newDate.timestamp())
301 |     elif hourMinute > 690 and hourMinute < 780:
302 |         newDate = newDate.replace(hour=11, minute=30, second=0, microsecond=0)
303 |         latestData.date = int(newDate.timestamp())
304 | 
305 |     isNextCycle = True
306 |     if dCycle == 0:
307 |         isNextCycle = False
308 |     elif cycle < 1440:
309 |         if len(oldDatas) > 0:
310 |             newMinutes = latestData.date // 60
311 |             lastData = oldDatas[len(oldDatas) - 1]
312 |             lastMinutes = lastData.date // 60
313 |             isNextCycle = newMinutes - lastMinutes >= cycle
314 |     else:
315 |         if cycle == 1440:
316 |             if len(oldDatas) > 0:
317 |                 lastDate = numToDate(oldDatas[len(oldDatas) - 1].date)
318 |                 isNextCycle = getDateNum(newDate.year, newDate.month, newDate.day, 0, 0, 0, 0) != getDateNum(lastDate.year, lastDate.month, lastDate.day, 0, 0, 0, 0)
319 |         elif cycle == 10080:
320 |             if len(oldDatas) > 0:
321 |                 firstDate = getDateNum(1970, 1, 5, 0, 0, 0, 0)
322 |                 lWeeks = ((oldDatas[len(oldDatas) - 1].date - firstDate) // 86400 + 1) // 7
323 |                 weeks = ((latestData.date - firstDate) // 86400 + 1) // 7
324 |                 isNextCycle = weeks > lWeeks
325 |         elif cycle == 43200:
326 |             if len(oldDatas) > 0:
327 |                 lastDate = numToDate(oldDatas[len(oldDatas) - 1].date)
328 |                 isNextCycle = newDate.year * 12 + newDate.month != lastDate.year * 12 + lastDate.month
329 |         elif cycle == 129600:
330 |             if len(oldDatas) > 0:
331 |                 lastDate = numToDate(oldDatas[len(oldDatas) - 1].date)
332 |                 isNextCycle = newDate.year * 4 + getSeason(newDate.month) != lastDate.year * 4 + getSeason(lastDate.month)
333 |         elif cycle == 259200:
334 |             if len(oldDatas) > 0:
335 |                 lastDate = numToDate(oldDatas[len(oldDatas) - 1].date)
336 |                 isNextCycle = newDate.year * 2 + (newDate.month // 6) != lastDate.year * 2 + (lastDate.month // 6)
337 |         elif cycle == 518400:
338 |             if len(oldDatas) > 0:
339 |                 lastDate = numToDate(oldDatas[len(oldDatas) - 1].date)
340 |                 isNextCycle = newDate.year != lastDate.year
341 | 
342 |     if isNextCycle:
343 |         newCycleData = SecurityData()
344 |         newCycleData.date = latestData.date
345 |         newCycleData.open = latestData.close
346 |         newCycleData.high = latestData.close
347 |         newCycleData.low = latestData.close
348 |         newCycleData.close = latestData.close
349 |         newCycleData.volume = latestData.volume - tickDataCache.lastVolume
350 |         newCycleData.amount = latestData.amount - tickDataCache.lastAmount
351 |         oldDatas.append(newCycleData)
352 |     else:
353 |         if len(oldDatas) > 0:
354 |             lastCycleData = oldDatas[len(oldDatas) - 1]
355 |             if dCycle == 0:
356 |                 thisDate = getDateNum(newDate.year, newDate.month, newDate.day, newDate.hour, newDate.minute, 0, 0)
357 |                 for data in oldDatas:
358 |                     if data.date == thisDate:
359 |                         if data.open == 0:
360 |                             data.open = latestData.open
361 |                         lastCycleData = data
362 |                         break
363 |             lastCycleData.close = latestData.close
364 |             if lastCycleData.high < latestData.close:
365 |                 lastCycleData.high = latestData.close
366 |             if lastCycleData.low > latestData.close:
367 |                 lastCycleData.low = latestData.close
368 |             lastCycleData.amount += latestData.amount - tickDataCache.lastAmount
369 |             lastCycleData.volume += latestData.volume - tickDataCache.lastVolume
370 | 
371 |     tickDataCache.code = code
372 |     tickDataCache.lastAmount = latestData.amount
373 |     tickDataCache.lastDate = latestData.date
374 |     tickDataCache.lastVolume = latestData.volume
375 | 
376 | # Map that stores adjustment factors, keyed by security code
377 | factorsMap = {}
378 | 
379 | def fq_price_func(price, factor):
380 |     """ Forward-adjustment (qfq) price function
381 |     price  the stock price
382 |     factor  the adjustment factor
383 |     @returns the adjusted price"""
384 |     cash_bt = factor.f1
385 |     bonus_shr = factor.f3
386 |     allot_pct = factor.f4
387 |     allot_price = factor.f2
388 |     return (10.0 * price - cash_bt + allot_pct * allot_price) / (10.0 + allot_pct + bonus_shr)
389 | 
390 | def fq_price_func2(price, factor):
391 |     """ Backward-adjustment (hfq) price function
392 |     price  the stock price
393 |     factor  the adjustment factor
394 |     @returns the adjusted price"""
395 |     cash_bt = factor.f1
396 |     bonus_shr = factor.f3
397 |     allot_pct = factor.f4
398 |     allot_price = factor.f2
399 |     return (price * (10.0 + allot_pct + bonus_shr) - allot_pct * allot_price + cash_bt) / 10.0
400 | 
401 | def convertXdrBeforePrice(kd, trade_date, factor):
402 |     """ Apply forward (pre-) adjustment to one bar
403 |     code  the security code
404 |     kd  the bar data
405 |     trade_date  the trade date
406 |     factor  array of adjustment factors"""
407 |     size = len(factor)
408 |     if size > 0:
409 |         pos = 0
410 |         date = kd.date
411 |         if kd.date < factor[len(factor) - 1].dwDate:
412 |             for i in range(size):
413 |                 if trade_date > 0 and trade_date < factor[i].dwDate:
414 |                     continue
415 |                 pos = i
416 |                 if date < factor[i].dwDate:
417 |                     break
418 |             for i in range(pos, size):
419 |                 if trade_date > 0 and trade_date < factor[i].dwDate:
420 |                     continue
421 |                 kd.open = fq_price_func(kd.open, factor[i])
422 |                 kd.high = fq_price_func(kd.high, factor[i])
423 |                 kd.low = fq_price_func(kd.low, factor[i])
424 |                 kd.close = fq_price_func(kd.close, factor[i])
425 | 
426 | def convertXdrAfterPrice(kd, trade_date, factor):
427 |     """ Apply backward (post-) adjustment to one bar
428 |     code  the security code
429 |     kd  the bar data
430 |     trade_date  the trade date
431 |     factor  array of adjustment factors"""
432 |     size = len(factor)
433 |     if size > 0:
434 |         date = kd.date
435 |         factors = []
436 |         for i in range(size):
437 |             if date < factor[i].dwDate:
438 |                 break
439 |             else:
440 |                 factors.insert(0, factor[i])
441 |         for i in range(len(factors)):
442 |             kd.open = fq_price_func2(kd.open, factors[i])
443 |             kd.high = fq_price_func2(kd.high, factors[i])
444 |             kd.low = fq_price_func2(kd.low, factors[i])
445 |             kd.close = fq_price_func2(kd.close, factors[i])
446 | 
447 | def convertXdr(code, rights_offering, datas):
448 |     """ Apply ex-dividend/ex-rights (XDR) adjustment
449 |     code  the security code
450 |     rights_offering  adjustment type (1 = forward, 2 = backward)
451 |     datas  array of bar data"""
452 |     if code in factorsMap:
453 |         factor = factorsMap[code]
454 |         datas_size = len(datas)
455 |         if datas_size > 0:
456 |             trade_date = datas[len(datas) - 1].date
457 |             for kd in datas:
458 |                 if rights_offering == 1:
459 |                     convertXdrBeforePrice(kd, trade_date, factor)
460 |                 elif rights_offering == 2:
461 |                     convertXdrAfterPrice(kd, trade_date, factor)
462 | 
--------------------------------------------------------------------------------
/model/module.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | from einops import rearrange, reduce
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Function
7 | import torch.nn.functional as F
8 | 
9 | 
10 | class DifferentiableEntropyFunction(Function):
11 |     @staticmethod
12 |     def forward(ctx, zq, basis, K, eps):
13 |         zb = (zq + 1) / 2
14 |         zi = ((zb * basis).sum(-1)).to(torch.int64)
15 |         cnt = torch.scatter_reduce(torch.zeros(2 ** K, device=zq.device, dtype=zq.dtype),
16 |                                    0,
17 |                                    zi.flatten(),
18 |                                    torch.ones_like(zi.flatten()).to(zq.dtype),
19 |                                    'sum')
20 |         prob = (cnt + eps) / (cnt + eps).sum()
21 |         H = -(prob * torch.log(prob)).sum()
22 |         ctx.save_for_backward(zq, zi, prob)
23 |         ctx.K = K
24 |         return H
25 | 
26 |     @staticmethod
27 |     def backward(ctx, grad_output):
28 |         zq, zi, prob = ctx.saved_tensors
29 |         grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K
30 |         reord_grad = grad_array[zi.flatten()].reshape(zi.shape)
31 |         grad_input = reord_grad.unsqueeze(-1) * zq
32 |         return grad_input, None, None, None, None
33 | 
34 | 
35 | def codebook_entropy(zq, basis, K, eps=1e-4):
36 |     return DifferentiableEntropyFunction.apply(zq, basis, K, eps)
37 | 
38 | 
39 | class BinarySphericalQuantizer(nn.Module):
40 |     def __init__(self, embed_dim, beta, gamma0, gamma, zeta,
41 |                  input_format='bchw',
42 |                  soft_entropy=True, group_size=9,
43 |                  persample_entropy_compute='analytical',
44 |                  cb_entropy_compute='group',
45 |                  l2_norm=True,
46 |
inv_temperature=1): 47 | """ 48 | Paper link: https://arxiv.org/pdf/2406.07548.pdf 49 | Here we use the official implementation of the BinarySphericalQuantizer. 50 | """ 51 | super().__init__() 52 | self.embed_dim = embed_dim 53 | self.beta = beta # loss weight for commit loss 54 | self.gamma0 = gamma0 # loss weight for entropy penalty 55 | self.gamma = gamma # loss weight for entropy penalty 56 | self.zeta = zeta # loss weight for entire entropy penalty 57 | self.input_format = input_format 58 | assert self.embed_dim % group_size == 0, "embed_dim must be divisible by group_size" 59 | self.num_groups = self.embed_dim // group_size 60 | self.group_size = group_size 61 | assert persample_entropy_compute in ['group', 'analytical'], "persample_entropy_compute must be either 'group' or 'analytical'" 62 | assert cb_entropy_compute in ['group', 'nce'], "cb_entropy_compute must be either 'group' or 'nce'" 63 | self.persample_entropy_compute = persample_entropy_compute 64 | self.cb_entropy_compute = cb_entropy_compute 65 | self.l2_norm = l2_norm 66 | self.inv_temperature = inv_temperature 67 | 68 | self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1)) 69 | self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1)) 70 | 71 | self.num_dimensions = 2 ** embed_dim 72 | self.bits_per_index = embed_dim 73 | 74 | # we only need to keep the codebook portion up to the group size 75 | # because we approximate the H loss with this subcode 76 | group_codes = torch.arange(2 ** self.group_size) 77 | group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:] 78 | self.register_buffer('group_codebook', group_codebook, persistent=False) 79 | 80 | self.soft_entropy = soft_entropy # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf 81 | 82 | def quantize(self, z): 83 | assert z.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}" 84 | 85 | zhat = torch.where(z > 0, 86 | torch.tensor(1, dtype=z.dtype, device=z.device), 87 | torch.tensor(-1, dtype=z.dtype, device=z.device)) 88 | return z + (zhat - z).detach() 89 | 90 | def forward(self, z): 91 | # if self.input_format == 'bchw': 92 | # z = rearrange(z, 'b c h w -> b h w c') 93 | zq = self.quantize(z) 94 | 95 | indices = self.codes_to_indexes(zq.detach()) 96 | group_indices = self.codes_to_group_indexes(zq.detach()) 97 | if not self.training: 98 | used_codes = torch.unique(indices, return_counts=False) 99 | else: 100 | used_codes = None 101 | 102 | q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. 
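        # Note: with l2_norm the {-1, +1} codes are rescaled by 1/sqrt(embed_dim),
        # so each quantized vector has unit L2 norm (it lies on the unit hypersphere).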
103 |         avg_prob = None  # only computed on the soft-entropy path; referenced in the return dict below
104 |         if self.soft_entropy:
105 |             persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z)
106 |             entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
107 |         else:
108 |             zb_by_sample = ((zq + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32)
109 |             persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample)
110 |             cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim)
111 |             entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
112 | 
113 |         zq = zq * q_scale
114 | 
115 |         # commit loss
116 |         commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1))
117 | 
118 |         # if self.input_format == 'bchw':
119 |         #     zq = rearrange(zq, 'b h w c -> b c h w')
120 | 
121 |         return (
122 |             zq,
123 |             commit_loss + self.zeta * entropy_penalty / self.inv_temperature,
124 |             {"H": cb_entropy, "used_codes": used_codes, "indices": indices, "group_indices": group_indices,
125 |              "avg_prob": avg_prob}
126 |         )
127 | 
128 |     def soft_entropy_loss(self, z):
129 |         # if we divide the code in subgroups of size group_size, the codebook will be of size 2 ** group_size
130 |         # the sub-code is the last group_size bits of the full code
131 |         group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1)
132 |         divided_z = rearrange(z, '... (g c) -> ... g c', c=self.group_size)
133 | 
134 |         # we calculate the distance between the divided_z and the codebook for each subgroup
135 |         distance = - 2 * torch.einsum('... g c, d c -> ... g d', divided_z, group_code_book)
136 |         prob = (-distance * self.inv_temperature).softmax(dim=-1)
137 |         if self.persample_entropy_compute == 'analytical':
138 |             if self.l2_norm:
139 |                 p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature)
140 |             else:
141 |                 p = torch.sigmoid(-4 * z * self.inv_temperature)
142 |             prob = torch.stack([p, 1 - p], dim=-1)
143 |             per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
144 |         else:
145 |             per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
146 | 
147 |         # macro average of the probability of each subgroup
148 |         avg_prob = reduce(prob, '... g d -> g d', 'mean')
149 |         codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)
150 | 
151 |         # the approximation of the entropy is the sum of the entropy of each subgroup
152 |         return per_sample_entropy, codebook_entropy.sum(), avg_prob
153 | 
154 |     def get_hard_per_sample_entropy(self, zb_by_sample):
155 |         probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1]
156 |         persample_entropy = - probs_per_dim * torch.log(probs_per_dim + 1e-8) - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8)
157 |         persample_entropy = persample_entropy.sum(-1)
158 |         return persample_entropy.mean()
159 | 
160 |     def codes_to_indexes(self, zhat):
161 |         """Converts a `code` to an index in the codebook.
162 |         Args:
163 |             zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
164 |         """
165 |         assert zhat.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}"
166 |         return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64)
167 | 
168 |     def codes_to_group_indexes(self, zhat):
169 |         """Converts a `code` to a list of indexes (in groups) in the codebook.
170 |         Args:
171 |             zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
172 |         """
173 |         zhat_in_group = rearrange(zhat, 'b ... (g c) -> b ... g c', c=self.group_size)
174 |         return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64)
175 | 
176 |     def indexes_to_codes(self, indices):
177 |         """Inverse of `codes_to_indexes`."""
178 |         indices = indices.unsqueeze(-1)
179 |         codes_non_centered = torch.remainder(
180 |             torch.floor_divide(indices, self.basis), 2
181 |         )
182 |         return codes_non_centered * 2 - 1
183 | 
184 |     def group_indexes_to_codes(self, group_indices):
185 |         """Inverse of `codes_to_group_indexes`."""
186 |         group_indices = group_indices.unsqueeze(-1)
187 |         codes_non_centered = torch.remainder(
188 |             torch.floor_divide(group_indices, self.group_basis), 2
189 |         )
190 |         codes_non_centered = rearrange(codes_non_centered, 'b ... g c -> b ... (g c)')
191 |         return codes_non_centered * 2 - 1
192 | 
193 |     def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
194 |         if normalize:
195 |             probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True)
196 |         else:
197 |             probs = count
198 |         H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
199 |         return H
200 | 
201 |     def get_group_codebook_entry(self, group_indices):
202 |         z_q = self.group_indexes_to_codes(group_indices)
203 |         q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
204 |         z_q = z_q * q_scale
205 |         if self.input_format == 'bchw':
206 |             h = w = int(z_q.shape[1] ** 0.5)
207 |             assert h * w == z_q.shape[1], 'Invalid sequence length'
208 |             z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
209 |         return z_q
210 | 
211 |     def get_codebook_entry(self, indices):
212 |         z_q = self.indexes_to_codes(indices)
213 |         q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
214 |         z_q = z_q * q_scale
215 |         if self.input_format == 'bchw':
216 |             h = w = int(z_q.shape[1] ** 0.5)
217 |             assert h * w == z_q.shape[1], 'Invalid sequence length'
218 |             z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
219 |         return z_q
220 | 
221 | 
222 | class BSQuantizer(nn.Module):
223 | 
224 |     def __init__(self, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size):
225 |         super().__init__()
226 |         self.codebook_dim = s1_bits + s2_bits
227 |         self.s1_bits = s1_bits
228 |         self.s2_bits = s2_bits
229 |         self.bsq = BinarySphericalQuantizer(self.codebook_dim, beta, gamma0, gamma, zeta, group_size=group_size)
230 | 
231 |     def bits_to_indices(self, bits):
232 |         bits = (bits >= 0).to(torch.long)
233 |         indices = 2 ** torch.arange(
234 |             0,
235 |             bits.shape[-1],
236 |             1,
237 |             dtype=torch.long,
238 |             device=bits.device,
239 |         )
240 |         return (bits * indices).sum(-1)
241 | 
242 |     def forward(self, z, half=False):
243 |         z = F.normalize(z, dim=-1)
244 |         quantized, bsq_loss, metrics = self.bsq(z)
245 |         if half:
246 |             q_pre = quantized[:, :, :self.s1_bits]
247 |             q_post = quantized[:, :, self.s1_bits:]
248 |             z_indices = [self.bits_to_indices(q_pre), self.bits_to_indices(q_post)]
249 |         else:
250 |             z_indices = self.bits_to_indices(quantized)
251 |         return bsq_loss, quantized, z_indices
252 | 
253 | 
254 | class RMSNorm(torch.nn.Module):
255 |     def __init__(self, dim: int, eps: float = 1e-5):
256 |         super().__init__()
257 |         self.eps = eps
258 |         self.weight = nn.Parameter(torch.ones(dim))
259 | 
260 |     def _norm(self, x):
261 |         return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
262 | 
263 |     def forward(self, x):
264 |         output = self._norm(x.float()).type_as(x)
265 |         return output * self.weight
266 | 
267 | 
268 | class FeedForward(nn.Module):
269 |     def __init__(self, d_model, ff_dim, ffn_dropout_p=0.0):
270 |         super().__init__()
271 | 
272 |         self.w1 = nn.Linear(d_model, ff_dim, bias=False)
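        # Together with w3 on the next line, w1 forms the SwiGLU gate/value pair;
        # forward() computes w2(silu(w1(x)) * w3(x)).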
273 | self.w3 = nn.Linear(d_model, ff_dim, bias=False) 274 | self.w2 = nn.Linear(ff_dim, d_model, bias=False) 275 | self.ffn_dropout = nn.Dropout(ffn_dropout_p) 276 | 277 | def forward(self, x): 278 | return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) 279 | 280 | 281 | class RotaryPositionalEmbedding(nn.Module): 282 | def __init__(self, dim): 283 | super().__init__() 284 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 285 | self.register_buffer("inv_freq", inv_freq) 286 | self.seq_len_cached = None 287 | self.cos_cached = None 288 | self.sin_cached = None 289 | 290 | def _update_cos_sin_cache(self, x, seq_len): 291 | if seq_len != self.seq_len_cached: 292 | self.seq_len_cached = seq_len 293 | t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) 294 | freqs = torch.einsum('i,j->ij', t, self.inv_freq) 295 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 296 | self.cos_cached = emb.cos()[None, None, :, :] 297 | self.sin_cached = emb.sin()[None, None, :, :] 298 | return self.cos_cached, self.sin_cached 299 | 300 | def forward(self, q, k): 301 | cos, sin = self._update_cos_sin_cache(q, q.shape[-2]) 302 | return ( 303 | (q * cos) + (self._rotate_half(q) * sin), 304 | (k * cos) + (self._rotate_half(k) * sin), 305 | ) 306 | 307 | def _rotate_half(self, x): 308 | x1, x2 = x.chunk(2, dim=-1) 309 | return torch.cat((-x2, x1), dim=-1) 310 | 311 | 312 | def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: 313 | L, S = query.size(-2), key.size(-2) 314 | scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale 315 | attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.device) 316 | 317 | if is_causal: 318 | assert attn_mask is None 319 | temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).to(query.device) 320 | attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) 321 | attn_bias.to(query.dtype) 322 | 323 | attn_weight = query @ key.transpose(-2, -1) * scale_factor 324 | attn_weight += attn_bias 325 | 326 | if attn_mask is not None: 327 | attn_mask_bias = torch.zeros_like(attn_weight) 328 | if attn_mask.dtype == torch.bool: 329 | attn_mask_bias.masked_fill_(attn_mask, float("-inf")) 330 | else: 331 | attn_mask_bias += attn_mask 332 | attn_weight += attn_mask_bias 333 | 334 | attn_weight = torch.softmax(attn_weight, dim=-1) 335 | attn_weight = torch.dropout(attn_weight, dropout_p, train=True) 336 | return attn_weight @ value 337 | 338 | 339 | class MultiHeadAttentionWithRoPE(nn.Module): 340 | def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout_p=0.0): 341 | super().__init__() 342 | self.d_model = d_model 343 | self.n_heads = n_heads 344 | self.head_dim = d_model // n_heads 345 | 346 | self.q_proj = nn.Linear(d_model, d_model) 347 | self.k_proj = nn.Linear(d_model, d_model) 348 | self.v_proj = nn.Linear(d_model, d_model) 349 | self.out_proj = nn.Linear(d_model, d_model) 350 | self.rotary = RotaryPositionalEmbedding(self.head_dim) 351 | self.attn_dropout_p = attn_dropout_p 352 | self.resid_dropout = nn.Dropout(resid_dropout_p) 353 | 354 | def forward(self, x, key_padding_mask=None): 355 | batch_size, seq_len, _ = x.shape 356 | 357 | q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) 358 | k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) 359 | v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) 360 | 
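        # RoPE is applied to queries and keys only (values are untouched); each
        # head-dimension pair is rotated by a position-dependent angle, so the
        # attention logits become a function of relative position.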
361 | q, k = self.rotary(q, k) 362 | 363 | if key_padding_mask is not None: 364 | attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) # [batch, 1, 1, seq_len] 365 | attn_mask = attn_mask.expand(-1, self.n_heads, seq_len, -1) # [batch, n_heads, q_len, k_len] 366 | else: 367 | attn_mask = None 368 | 369 | attn_output = scaled_dot_product_attention( 370 | q, k, v, 371 | attn_mask=attn_mask, 372 | dropout_p=self.attn_dropout_p, 373 | is_causal=True 374 | ) 375 | 376 | attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model) 377 | return self.resid_dropout(self.out_proj(attn_output)) 378 | 379 | 380 | class MultiHeadCrossAttentionWithRoPE(nn.Module): 381 | def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout=0.0): 382 | super().__init__() 383 | self.d_model = d_model 384 | self.n_heads = n_heads 385 | self.head_dim = d_model // n_heads 386 | 387 | self.q_proj = nn.Linear(d_model, d_model) 388 | self.k_proj = nn.Linear(d_model, d_model) 389 | self.v_proj = nn.Linear(d_model, d_model) 390 | self.out_proj = nn.Linear(d_model, d_model) 391 | self.rotary = RotaryPositionalEmbedding(self.head_dim) 392 | self.attn_dropout_p = attn_dropout_p 393 | self.resid_dropout = nn.Dropout(resid_dropout) 394 | 395 | def forward(self, query, key, value, key_padding_mask=None): 396 | batch_size, q_len, _ = query.shape 397 | _, seq_len, _ = key.shape 398 | 399 | q = self.q_proj(query).view(batch_size, q_len, self.n_heads, self.head_dim).transpose(1, 2) 400 | k = self.k_proj(key).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) 401 | v = self.v_proj(value).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2) 402 | 403 | q, k = self.rotary(q, k) 404 | 405 | if key_padding_mask is not None: 406 | attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) 407 | attn_mask = attn_mask.expand(-1, self.n_heads, q_len, -1) 408 | else: 409 | attn_mask = None 410 | 411 | is_causal_flag = self.training 412 | 413 | attn_output = scaled_dot_product_attention( 414 | q, k, v, 415 | attn_mask=attn_mask, 416 | dropout_p=self.attn_dropout_p, 417 | is_causal=is_causal_flag 418 | ) 419 | 420 | attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model) 421 | return self.resid_dropout(self.out_proj(attn_output)) 422 | 423 | 424 | class HierarchicalEmbedding(nn.Module): 425 | def __init__(self, s1_bits, s2_bits, d_model=256): 426 | super().__init__() 427 | self.s1_bits = s1_bits 428 | self.s2_bits = s2_bits 429 | 430 | vocab_s1 = 2 ** s1_bits 431 | vocab_s2 = 2 ** s2_bits 432 | 433 | self.emb_s1 = nn.Embedding(vocab_s1, d_model) 434 | self.emb_s2 = nn.Embedding(vocab_s2, d_model) 435 | self.d_model = d_model 436 | self.fusion_proj = nn.Linear(d_model * 2, d_model) 437 | 438 | nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5) 439 | nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5) 440 | 441 | def forward(self, token_ids): 442 | """Inputs: 443 | token_ids: [batch_size, seq_len] token ID 444 | Output: [batch_size, seq_len, d_model] 445 | """ 446 | if isinstance(token_ids, tuple) or isinstance(token_ids, list): 447 | s1_ids, s2_ids = token_ids 448 | else: 449 | s1_ids, s2_ids = self.split_token(token_ids, self.s2_bits) 450 | s1_emb = self.emb_s1(s1_ids) * math.sqrt(self.d_model) 451 | s2_emb = self.emb_s2(s2_ids) * math.sqrt(self.d_model) 452 | return self.fusion_proj(torch.cat([s1_emb, s2_emb], dim=-1)) 453 | 454 | 455 | class DependencyAwareLayer(nn.Module): 456 | def __init__(self, 
457 |         super().__init__()
458 |         self.cross_attn = MultiHeadCrossAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout)
459 |         self.norm = RMSNorm(d_model)
460 | 
461 |     def forward(self, hidden_states, sibling_embed, key_padding_mask=None):
462 |         """hidden_states: [batch, seq_len, d_model]
463 |         sibling_embed: Embedding from another subtoken
464 |         """
465 |         attn_out = self.cross_attn(
466 |             query=sibling_embed,
467 |             key=hidden_states,
468 |             value=hidden_states,
469 |             key_padding_mask=key_padding_mask
470 |         )
471 |         return self.norm(hidden_states + attn_out)
472 | 
473 | 
474 | class TransformerBlock(nn.Module):
475 |     def __init__(self, d_model, n_heads, ff_dim=1024, ffn_dropout_p=0.0, attn_dropout_p=0.0, resid_dropout_p=0.0):
476 |         super().__init__()
477 |         self.norm1 = RMSNorm(d_model)
478 |         self.self_attn = MultiHeadAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout_p)
479 |         self.norm2 = RMSNorm(d_model)
480 |         self.ffn = FeedForward(d_model, ff_dim, ffn_dropout_p)
481 | 
482 |     def forward(self, x, key_padding_mask=None):
483 |         residual = x
484 |         x = self.norm1(x)
485 |         attn_out = self.self_attn(x, key_padding_mask=key_padding_mask)
486 |         x = residual + attn_out
487 | 
488 |         residual = x
489 |         x = self.norm2(x)
490 |         ffn_out = self.ffn(x)
491 |         x = residual + ffn_out
492 |         return x
493 | 
494 | 
495 | class DualHead(nn.Module):
496 |     def __init__(self, s1_bits, s2_bits, d_model):
497 |         super().__init__()
498 |         self.vocab_s1 = 2 ** s1_bits
499 |         self.vocab_s2 = 2 ** s2_bits
500 |         self.proj_s1 = nn.Linear(d_model, self.vocab_s1)
501 |         self.proj_s2 = nn.Linear(d_model, self.vocab_s2)
502 | 
503 |     def compute_loss(self, s1_logits, s2_logits, s1_targets, s2_targets, padding_mask=None):
504 |         if padding_mask is not None:
505 |             valid_mask = (padding_mask == 0)
506 |             s1_logits = s1_logits[valid_mask]
507 |             s2_logits = s2_logits[valid_mask]
508 |             s1_targets = s1_targets[valid_mask]
509 |             s2_targets = s2_targets[valid_mask]
510 |             ce_s1 = F.cross_entropy(s1_logits, s1_targets)
511 |             ce_s2 = F.cross_entropy(s2_logits, s2_targets)
512 |         else:
513 |             ce_s1 = F.cross_entropy(s1_logits.reshape(-1, self.vocab_s1), s1_targets.reshape(-1))
514 |             ce_s2 = F.cross_entropy(s2_logits.reshape(-1, self.vocab_s2), s2_targets.reshape(-1))
515 |         ce_loss = (ce_s1 + ce_s2) / 2
516 |         return ce_loss, ce_s1, ce_s2
517 | 
518 |     def forward(self, x):
519 |         return self.proj_s1(x)
520 | 
521 |     def cond_forward(self, x2):
522 |         return self.proj_s2(x2)
523 | 
524 | 
525 | class FixedEmbedding(nn.Module):
526 |     def __init__(self, c_in, d_model):
527 |         super(FixedEmbedding, self).__init__()
528 | 
529 |         w = torch.zeros(c_in, d_model).float()
530 |         w.requires_grad = False  # fix: was 'require_grad', a typo that silently set a dead attribute
531 | 
532 |         position = torch.arange(0, c_in).float().unsqueeze(1)
533 |         div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
534 | 
535 |         w[:, 0::2] = torch.sin(position * div_term)
536 |         w[:, 1::2] = torch.cos(position * div_term)
537 | 
538 |         self.emb = nn.Embedding(c_in, d_model)
539 |         self.emb.weight = nn.Parameter(w, requires_grad=False)
540 | 
541 |     def forward(self, x):
542 |         return self.emb(x).detach()
543 | 
544 | 
545 | class TemporalEmbedding(nn.Module):
546 |     def __init__(self, d_model, learn_pe):
547 |         super(TemporalEmbedding, self).__init__()
548 | 
549 |         minute_size = 60
550 |         hour_size = 24
551 |         weekday_size = 7
552 |         day_size = 32
553 |         month_size = 13
554 | 
555 |         Embed = FixedEmbedding if not learn_pe else nn.Embedding
556 |         self.minute_embed = 
Embed(minute_size, d_model) 557 | self.hour_embed = Embed(hour_size, d_model) 558 | self.weekday_embed = Embed(weekday_size, d_model) 559 | self.day_embed = Embed(day_size, d_model) 560 | self.month_embed = Embed(month_size, d_model) 561 | 562 | def forward(self, x): 563 | x = x.long() 564 | 565 | minute_x = self.minute_embed(x[:, :, 0]) 566 | hour_x = self.hour_embed(x[:, :, 1]) 567 | weekday_x = self.weekday_embed(x[:, :, 2]) 568 | day_x = self.day_embed(x[:, :, 3]) 569 | month_x = self.month_embed(x[:, :, 4]) 570 | 571 | return hour_x + weekday_x + day_x + month_x + minute_x 572 | 573 | 574 | 575 | 576 | 577 | 578 | -------------------------------------------------------------------------------- /facecat/model/module.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from einops import rearrange, reduce 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Function 7 | import torch.nn.functional as F 8 | 9 | 10 | class DifferentiableEntropyFunction(Function): 11 | @staticmethod 12 | def forward(ctx, zq, basis, K, eps): 13 | zb = (zq + 1) / 2 14 | zi = ((zb * basis).sum(-1)).to(torch.int64) 15 | cnt = torch.scatter_reduce(torch.zeros(2 ** K, device=zq.device, dtype=zq.dtype), 16 | 0, 17 | zi.flatten(), 18 | torch.ones_like(zi.flatten()).to(zq.dtype), 19 | 'sum') 20 | prob = (cnt + eps) / (cnt + eps).sum() 21 | H = -(prob * torch.log(prob)).sum() 22 | ctx.save_for_backward(zq, zi, prob) 23 | ctx.K = K 24 | return H 25 | 26 | @staticmethod 27 | def backward(ctx, grad_output): 28 | zq, zi, prob = ctx.saved_tensors 29 | grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K 30 | reord_grad = grad_array[zi.flatten()].reshape(zi.shape) 31 | grad_input = reord_grad.unsqueeze(-1) * zq 32 | return grad_input, None, None, None, None 33 | 34 | 35 | def codebook_entropy(zq, basis, K, eps=1e-4): 36 | return DifferentiableEntropyFunction.apply(zq, basis, K, eps) 37 | 38 | 39 | class BinarySphericalQuantizer(nn.Module): 40 | def __init__(self, embed_dim, beta, gamma0, gamma, zeta, 41 | input_format='bchw', 42 | soft_entropy=True, group_size=9, 43 | persample_entropy_compute='analytical', 44 | cb_entropy_compute='group', 45 | l2_norm=True, 46 | inv_temperature=1): 47 | """ 48 | Paper link: https://arxiv.org/pdf/2406.07548.pdf 49 | Here we use the official implementation of the BinarySphericalQuantizer. 
50 | """ 51 | super().__init__() 52 | self.embed_dim = embed_dim 53 | self.beta = beta # loss weight for commit loss 54 | self.gamma0 = gamma0 # loss weight for entropy penalty 55 | self.gamma = gamma # loss weight for entropy penalty 56 | self.zeta = zeta # loss weight for entire entropy penalty 57 | self.input_format = input_format 58 | assert self.embed_dim % group_size == 0, "embed_dim must be divisible by group_size" 59 | self.num_groups = self.embed_dim // group_size 60 | self.group_size = group_size 61 | assert persample_entropy_compute in ['group', 'analytical'], "persample_entropy_compute must be either 'group' or 'analytical'" 62 | assert cb_entropy_compute in ['group', 'nce'], "cb_entropy_compute must be either 'group' or 'nce'" 63 | self.persample_entropy_compute = persample_entropy_compute 64 | self.cb_entropy_compute = cb_entropy_compute 65 | self.l2_norm = l2_norm 66 | self.inv_temperature = inv_temperature 67 | 68 | self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1)) 69 | self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1)) 70 | 71 | self.num_dimensions = 2 ** embed_dim 72 | self.bits_per_index = embed_dim 73 | 74 | # we only need to keep the codebook portion up to the group size 75 | # because we approximate the H loss with this subcode 76 | group_codes = torch.arange(2 ** self.group_size) 77 | group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:] 78 | self.register_buffer('group_codebook', group_codebook, persistent=False) 79 | 80 | self.soft_entropy = soft_entropy # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf 81 | 82 | def quantize(self, z): 83 | assert z.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}" 84 | 85 | zhat = torch.where(z > 0, 86 | torch.tensor(1, dtype=z.dtype, device=z.device), 87 | torch.tensor(-1, dtype=z.dtype, device=z.device)) 88 | return z + (zhat - z).detach() 89 | 90 | def forward(self, z): 91 | # if self.input_format == 'bchw': 92 | # z = rearrange(z, 'b c h w -> b h w c') 93 | zq = self.quantize(z) 94 | 95 | indices = self.codes_to_indexes(zq.detach()) 96 | group_indices = self.codes_to_group_indexes(zq.detach()) 97 | if not self.training: 98 | used_codes = torch.unique(indices, return_counts=False) 99 | else: 100 | used_codes = None 101 | 102 | q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1. 
103 | 
104 |         if self.soft_entropy:
105 |             persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z)
106 |             entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
107 |         else:
108 |             zb_by_sample = ((zq + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32)
109 |             persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample)
110 |             cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim)
111 |             entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
112 |             avg_prob = None  # fix: defined on both branches so the metrics dict below cannot raise NameError
113 |         zq = zq * q_scale
114 | 
115 |         # commit loss
116 |         commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1))
117 | 
118 |         # if self.input_format == 'bchw':
119 |         #     zq = rearrange(zq, 'b h w c -> b c h w')
120 | 
121 |         return (
122 |             zq,
123 |             commit_loss + self.zeta * entropy_penalty / self.inv_temperature,
124 |             {"H": cb_entropy, "used_codes": used_codes, "indices": indices, "group_indices": group_indices,
125 |              "avg_prob": avg_prob}
126 |         )
127 | 
128 |     def soft_entropy_loss(self, z):
129 |         # if we divide the code in subgroups of size group_size, the codebook will be of size 2 ** group_size
130 |         # the sub-code is the last group_size bits of the full code
131 |         group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1)
132 |         divided_z = rearrange(z, '... (g c) -> ... g c', c=self.group_size)
133 | 
134 |         # we calculate the distance between the divided_z and the codebook for each subgroup
135 |         distance = - 2 * torch.einsum('... g c, d c ->... g d', divided_z, group_code_book)
136 |         prob = (-distance * self.inv_temperature).softmax(dim=-1)
137 |         if self.persample_entropy_compute == 'analytical':
138 |             if self.l2_norm:
139 |                 p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature)
140 |             else:
141 |                 p = torch.sigmoid(-4 * z * self.inv_temperature)
142 |             prob = torch.stack([p, 1 - p], dim=-1)
143 |             per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
144 |         else:
145 |             per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
146 | 
147 |         # macro average of the probability of each subgroup
148 |         avg_prob = reduce(prob, '... g d ->g d', 'mean')
149 |         codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)
150 | 
151 |         # the approximation of the entropy is the sum of the entropy of each subgroup
152 |         return per_sample_entropy, codebook_entropy.sum(), avg_prob
153 | 
154 |     def get_hard_per_sample_entropy(self, zb_by_sample):
155 |         probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1]
156 |         persample_entropy = - probs_per_dim * torch.log(probs_per_dim + 1e-8) - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8)
157 |         persample_entropy = persample_entropy.sum(-1)
158 |         return persample_entropy.mean()
159 | 
160 |     def codes_to_indexes(self, zhat):
161 |         """Converts a `code` to an index in the codebook.
162 |         Args:
163 |             zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
164 |         """
165 |         assert zhat.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}"
166 |         return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64)
167 | 
168 |     def codes_to_group_indexes(self, zhat):
169 |         """Converts a `code` to a list of indexes (in groups) in the codebook.
170 |         Args:
171 |             zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
172 |         """
173 |         zhat_in_group = rearrange(zhat, 'b ... (g c) -> b ... g c', c=self.group_size)
174 |         return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64)
175 | 
176 |     def indexes_to_codes(self, indices):
177 |         """Inverse of `codes_to_indexes`."""
178 |         indices = indices.unsqueeze(-1)
179 |         codes_non_centered = torch.remainder(
180 |             torch.floor_divide(indices, self.basis), 2
181 |         )
182 |         return codes_non_centered * 2 - 1
183 | 
184 |     def group_indexes_to_codes(self, group_indices):
185 |         """Inverse of `codes_to_group_indexes`."""
186 |         group_indices = group_indices.unsqueeze(-1)
187 |         codes_non_centered = torch.remainder(
188 |             torch.floor_divide(group_indices, self.group_basis), 2
189 |         )
190 |         codes_non_centered = rearrange(codes_non_centered, 'b ... g c -> b ... (g c)')
191 |         return codes_non_centered * 2 - 1
192 | 
193 |     def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
194 |         if normalize:
195 |             probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True)
196 |         else:
197 |             probs = count
198 |         H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
199 |         return H
200 | 
201 |     def get_group_codebook_entry(self, group_indices):
202 |         z_q = self.group_indexes_to_codes(group_indices)
203 |         q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
204 |         z_q = z_q * q_scale
205 |         if self.input_format == 'bchw':
206 |             h = w = int(z_q.shape[1] ** 0.5)  # fix: chained assignment; the original 'h, w = int(...)' would raise TypeError
207 |             assert h * w == z_q.shape[1], 'Invalid sequence length'
208 |             z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
209 |         return z_q
210 | 
211 |     def get_codebook_entry(self, indices):
212 |         z_q = self.indexes_to_codes(indices)
213 |         q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
214 |         z_q = z_q * q_scale
215 |         if self.input_format == 'bchw':
216 |             h = w = int(z_q.shape[1] ** 0.5)  # fix: same chained assignment as above
217 |             assert h * w == z_q.shape[1], 'Invalid sequence length'
218 |             z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
219 |         return z_q
220 | 
221 | 
222 | class BSQuantizer(nn.Module):
223 | 
224 |     def __init__(self, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size):
225 |         super().__init__()
226 |         self.codebook_dim = s1_bits + s2_bits
227 |         self.s1_bits = s1_bits
228 |         self.s2_bits = s2_bits
229 |         self.bsq = BinarySphericalQuantizer(self.codebook_dim, beta, gamma0, gamma, zeta, group_size=group_size)
230 | 
231 |     def bits_to_indices(self, bits):
232 |         bits = (bits >= 0).to(torch.long)
233 |         indices = 2 ** torch.arange(
234 |             0,
235 |             bits.shape[-1],
236 |             1,
237 |             dtype=torch.long,
238 |             device=bits.device,
239 |         )
240 |         return (bits * indices).sum(-1)
241 | 
242 |     def forward(self, z, half=False):
243 |         z = F.normalize(z, dim=-1)
244 |         quantized, bsq_loss, metrics = self.bsq(z)
245 |         if half:
246 |             q_pre = quantized[:, :, :self.s1_bits]
247 |             q_post = quantized[:, :, self.s1_bits:]
248 |             z_indices = [self.bits_to_indices(q_pre), self.bits_to_indices(q_post)]
249 |         else:
250 |             z_indices = self.bits_to_indices(quantized)
251 |         return bsq_loss, quantized, z_indices
252 | 
253 | 
254 | class RMSNorm(torch.nn.Module):
255 |     def __init__(self, dim: int, eps: float = 1e-5):
256 |         super().__init__()
257 |         self.eps = eps
258 |         self.weight = nn.Parameter(torch.ones(dim))
259 | 
260 |     def _norm(self, x):
261 |         return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
262 | 
263 |     def forward(self, x):
264 |         output = self._norm(x.float()).type_as(x)
265 |         return output * self.weight
266 | 
267 | 
268 | class FeedForward(nn.Module):
269 |     def __init__(self, d_model, ff_dim, ffn_dropout_p=0.0):
270 |         super().__init__()
271 | 
272 |         self.w1 = nn.Linear(d_model, ff_dim, bias=False)
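        # Explanatory comment (not in the original source): w1/w3/w2 form a
        # SwiGLU-style gated feed-forward; forward() computes w2(silu(w1(x)) * w3(x)),
        # i.e. the silu branch gates the w3 projection before mapping back to d_model.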
273 |         self.w3 = nn.Linear(d_model, ff_dim, bias=False)
274 |         self.w2 = nn.Linear(ff_dim, d_model, bias=False)
275 |         self.ffn_dropout = nn.Dropout(ffn_dropout_p)
276 | 
277 |     def forward(self, x):
278 |         return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
279 | 
280 | 
281 | class RotaryPositionalEmbedding(nn.Module):
282 |     def __init__(self, dim):
283 |         super().__init__()
284 |         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
285 |         self.register_buffer("inv_freq", inv_freq)
286 |         self.seq_len_cached = None
287 |         self.cos_cached = None
288 |         self.sin_cached = None
289 | 
290 |     def _update_cos_sin_cache(self, x, seq_len):
291 |         if seq_len != self.seq_len_cached:
292 |             self.seq_len_cached = seq_len
293 |             t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
294 |             freqs = torch.einsum('i,j->ij', t, self.inv_freq)
295 |             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
296 |             self.cos_cached = emb.cos()[None, None, :, :]
297 |             self.sin_cached = emb.sin()[None, None, :, :]
298 |         return self.cos_cached, self.sin_cached
299 | 
300 |     def forward(self, q, k):
301 |         cos, sin = self._update_cos_sin_cache(q, q.shape[-2])
302 |         return (
303 |             (q * cos) + (self._rotate_half(q) * sin),
304 |             (k * cos) + (self._rotate_half(k) * sin),
305 |         )
306 | 
307 |     def _rotate_half(self, x):
308 |         x1, x2 = x.chunk(2, dim=-1)
309 |         return torch.cat((-x2, x1), dim=-1)
310 | 
311 | 
312 | def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
313 |     L, S = query.size(-2), key.size(-2)
314 |     scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
315 |     attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.device)
316 | 
317 |     if is_causal:
318 |         assert attn_mask is None
319 |         temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).to(query.device)
320 |         attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
321 |         attn_bias = attn_bias.to(query.dtype)  # fix: assign the cast (the original statement discarded its result)
322 | 
323 |     attn_weight = query @ key.transpose(-2, -1) * scale_factor
324 |     attn_weight += attn_bias
325 | 
326 |     if attn_mask is not None:
327 |         attn_mask_bias = torch.zeros_like(attn_weight)
328 |         if attn_mask.dtype == torch.bool:
329 |             attn_mask_bias.masked_fill_(attn_mask, float("-inf"))
330 |         else:
331 |             attn_mask_bias += attn_mask
332 |         attn_weight += attn_mask_bias
333 | 
334 |     attn_weight = torch.softmax(attn_weight, dim=-1)
335 |     attn_weight = torch.dropout(attn_weight, dropout_p, train=True)  # NOTE: train=True applies dropout unconditionally; callers should pass dropout_p=0.0 at inference
336 |     return attn_weight @ value
337 | 
338 | 
339 | class MultiHeadAttentionWithRoPE(nn.Module):
340 |     def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout_p=0.0):
341 |         super().__init__()
342 |         self.d_model = d_model
343 |         self.n_heads = n_heads
344 |         self.head_dim = d_model // n_heads
345 | 
346 |         self.q_proj = nn.Linear(d_model, d_model)
347 |         self.k_proj = nn.Linear(d_model, d_model)
348 |         self.v_proj = nn.Linear(d_model, d_model)
349 |         self.out_proj = nn.Linear(d_model, d_model)
350 |         self.rotary = RotaryPositionalEmbedding(self.head_dim)
351 |         self.attn_dropout_p = attn_dropout_p
352 |         self.resid_dropout = nn.Dropout(resid_dropout_p)
353 | 
354 |     def forward(self, x, key_padding_mask=None):
355 |         batch_size, seq_len, _ = x.shape
356 | 
357 |         q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
358 |         k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
359 |         v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
360 | 
361 |         q, k = self.rotary(q, k)
362 | 
363 |         if key_padding_mask is not None:
364 |             attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, seq_len]
365 |             attn_mask = attn_mask.expand(-1, self.n_heads, seq_len, -1)  # [batch, n_heads, q_len, k_len]
366 |         else:
367 |             attn_mask = None
368 | 
369 |         attn_output = scaled_dot_product_attention(
370 |             q, k, v,
371 |             attn_mask=attn_mask,
372 |             dropout_p=self.attn_dropout_p,
373 |             is_causal=True
374 |         )
375 | 
376 |         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
377 |         return self.resid_dropout(self.out_proj(attn_output))
378 | 
379 | 
380 | class MultiHeadCrossAttentionWithRoPE(nn.Module):
381 |     def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout=0.0):
382 |         super().__init__()
383 |         self.d_model = d_model
384 |         self.n_heads = n_heads
385 |         self.head_dim = d_model // n_heads
386 | 
387 |         self.q_proj = nn.Linear(d_model, d_model)
388 |         self.k_proj = nn.Linear(d_model, d_model)
389 |         self.v_proj = nn.Linear(d_model, d_model)
390 |         self.out_proj = nn.Linear(d_model, d_model)
391 |         self.rotary = RotaryPositionalEmbedding(self.head_dim)
392 |         self.attn_dropout_p = attn_dropout_p
393 |         self.resid_dropout = nn.Dropout(resid_dropout)
394 | 
395 |     def forward(self, query, key, value, key_padding_mask=None):
396 |         batch_size, q_len, _ = query.shape
397 |         _, seq_len, _ = key.shape
398 | 
399 |         q = self.q_proj(query).view(batch_size, q_len, self.n_heads, self.head_dim).transpose(1, 2)
400 |         k = self.k_proj(key).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
401 |         v = self.v_proj(value).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
402 | 
403 |         q, k = self.rotary(q, k)  # NOTE: the cos/sin cache is sized from q, so q and k must share one sequence length here
404 | 
405 |         if key_padding_mask is not None:
406 |             attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
407 |             attn_mask = attn_mask.expand(-1, self.n_heads, q_len, -1)
408 |         else:
409 |             attn_mask = None
410 | 
411 |         is_causal_flag = self.training  # causal mask only while training; at inference the query attends over the full context
412 | 
413 |         attn_output = scaled_dot_product_attention(
414 |             q, k, v,
415 |             attn_mask=attn_mask,
416 |             dropout_p=self.attn_dropout_p,
417 |             is_causal=is_causal_flag
418 |         )
419 | 
420 |         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
421 |         return self.resid_dropout(self.out_proj(attn_output))
422 | 
423 | 
424 | class HierarchicalEmbedding(nn.Module):
425 |     def __init__(self, s1_bits, s2_bits, d_model=256):
426 |         super().__init__()
427 |         self.s1_bits = s1_bits
428 |         self.s2_bits = s2_bits
429 | 
430 |         vocab_s1 = 2 ** s1_bits
431 |         vocab_s2 = 2 ** s2_bits
432 | 
433 |         self.emb_s1 = nn.Embedding(vocab_s1, d_model)
434 |         self.emb_s2 = nn.Embedding(vocab_s2, d_model)
435 |         self.d_model = d_model
436 |         self.fusion_proj = nn.Linear(d_model * 2, d_model)
437 | 
438 |         nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5)
439 |         nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5)
440 | 
441 |     def forward(self, token_ids):
442 |         """Inputs:
443 |             token_ids: [batch_size, seq_len] fused token IDs, or an (s1_ids, s2_ids) tuple/list
444 |         Output: [batch_size, seq_len, d_model]
445 |         """
446 |         if isinstance(token_ids, tuple) or isinstance(token_ids, list):
447 |             s1_ids, s2_ids = token_ids
448 |         else:
449 |             s1_ids, s2_ids = token_ids >> self.s2_bits, token_ids & ((1 << self.s2_bits) - 1)  # NOTE: the referenced split_token helper was undefined; this assumes fused ids = (s1 << s2_bits) | s2
450 |         s1_emb = self.emb_s1(s1_ids) * math.sqrt(self.d_model)
451 |         s2_emb = self.emb_s2(s2_ids) * math.sqrt(self.d_model)
452 |         return self.fusion_proj(torch.cat([s1_emb, s2_emb], dim=-1))
453 | 
454 | 
455 | class DependencyAwareLayer(nn.Module):
456 |     def __init__(self, d_model, n_heads=4, attn_dropout_p=0.0, resid_dropout=0.0):
457 |         super().__init__()
458 |         self.cross_attn = MultiHeadCrossAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout)
459 |         self.norm = RMSNorm(d_model)
460 | 
461 |     def forward(self, hidden_states, sibling_embed, key_padding_mask=None):
462 |         """hidden_states: [batch, seq_len, d_model]
463 |         sibling_embed: Embedding from another subtoken
464 |         """
465 |         attn_out = self.cross_attn(
466 |             query=sibling_embed,
467 |             key=hidden_states,
468 |             value=hidden_states,
469 |             key_padding_mask=key_padding_mask
470 |         )
471 |         return self.norm(hidden_states + attn_out)
472 | 
473 | 
474 | class TransformerBlock(nn.Module):
475 |     def __init__(self, d_model, n_heads, ff_dim=1024, ffn_dropout_p=0.0, attn_dropout_p=0.0, resid_dropout_p=0.0):
476 |         super().__init__()
477 |         self.norm1 = RMSNorm(d_model)
478 |         self.self_attn = MultiHeadAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout_p)
479 |         self.norm2 = RMSNorm(d_model)
480 |         self.ffn = FeedForward(d_model, ff_dim, ffn_dropout_p)
481 | 
482 |     def forward(self, x, key_padding_mask=None):
483 |         residual = x
484 |         x = self.norm1(x)
485 |         attn_out = self.self_attn(x, key_padding_mask=key_padding_mask)
486 |         x = residual + attn_out
487 | 
488 |         residual = x
489 |         x = self.norm2(x)
490 |         ffn_out = self.ffn(x)
491 |         x = residual + ffn_out
492 |         return x
493 | 
494 | 
495 | class DualHead(nn.Module):
496 |     def __init__(self, s1_bits, s2_bits, d_model):
497 |         super().__init__()
498 |         self.vocab_s1 = 2 ** s1_bits
499 |         self.vocab_s2 = 2 ** s2_bits
500 |         self.proj_s1 = nn.Linear(d_model, self.vocab_s1)
501 |         self.proj_s2 = nn.Linear(d_model, self.vocab_s2)
502 | 
503 |     def compute_loss(self, s1_logits, s2_logits, s1_targets, s2_targets, padding_mask=None):
504 |         if padding_mask is not None:
505 |             valid_mask = (padding_mask == 0)
506 |             s1_logits = s1_logits[valid_mask]
507 |             s2_logits = s2_logits[valid_mask]
508 |             s1_targets = s1_targets[valid_mask]
509 |             s2_targets = s2_targets[valid_mask]
510 |             ce_s1 = F.cross_entropy(s1_logits, s1_targets)
511 |             ce_s2 = F.cross_entropy(s2_logits, s2_targets)
512 |         else:
513 |             ce_s1 = F.cross_entropy(s1_logits.reshape(-1, self.vocab_s1), s1_targets.reshape(-1))
514 |             ce_s2 = F.cross_entropy(s2_logits.reshape(-1, self.vocab_s2), s2_targets.reshape(-1))
515 |         ce_loss = (ce_s1 + ce_s2) / 2
516 |         return ce_loss, ce_s1, ce_s2
517 | 
518 |     def forward(self, x):
519 |         return self.proj_s1(x)
520 | 
521 |     def cond_forward(self, x2):
522 |         return self.proj_s2(x2)
523 | 
524 | 
525 | class FixedEmbedding(nn.Module):
526 |     def __init__(self, c_in, d_model):
527 |         super(FixedEmbedding, self).__init__()
528 | 
529 |         w = torch.zeros(c_in, d_model).float()
530 |         w.requires_grad = False  # fix: was 'require_grad', a typo that silently set a dead attribute
531 | 
532 |         position = torch.arange(0, c_in).float().unsqueeze(1)
533 |         div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
534 | 
535 |         w[:, 0::2] = torch.sin(position * div_term)
536 |         w[:, 1::2] = torch.cos(position * div_term)
537 | 
538 |         self.emb = nn.Embedding(c_in, d_model)
539 |         self.emb.weight = nn.Parameter(w, requires_grad=False)
540 | 
541 |     def forward(self, x):
542 |         return self.emb(x).detach()
543 | 
544 | 
545 | class TemporalEmbedding(nn.Module):
546 |     def __init__(self, d_model, learn_pe):
547 |         super(TemporalEmbedding, self).__init__()
548 | 
549 |         minute_size = 60
550 |         hour_size = 24
551 |         weekday_size = 7
552 |         day_size = 32
553 |         month_size = 13
554 | 
555 |         Embed = FixedEmbedding if not learn_pe else nn.Embedding
556 |         self.minute_embed = 
Embed(minute_size, d_model) 557 | self.hour_embed = Embed(hour_size, d_model) 558 | self.weekday_embed = Embed(weekday_size, d_model) 559 | self.day_embed = Embed(day_size, d_model) 560 | self.month_embed = Embed(month_size, d_model) 561 | 562 | def forward(self, x): 563 | x = x.long() 564 | 565 | minute_x = self.minute_embed(x[:, :, 0]) 566 | hour_x = self.hour_embed(x[:, :, 1]) 567 | weekday_x = self.weekday_embed(x[:, :, 2]) 568 | day_x = self.day_embed(x[:, :, 3]) 569 | month_x = self.month_embed(x[:, :, 4]) 570 | 571 | return hour_x + weekday_x + day_x + month_x + minute_x 572 | 573 | 574 | 575 | 576 | 577 | 578 | -------------------------------------------------------------------------------- /model/kronos.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | from huggingface_hub import PyTorchModelHubMixin 5 | import sys 6 | 7 | from tqdm import trange 8 | 9 | sys.path.append("../") 10 | from model.module import * 11 | 12 | 13 | class KronosTokenizer(nn.Module, PyTorchModelHubMixin): 14 | """ 15 | KronosTokenizer module for tokenizing input data using a hybrid quantization approach. 16 | 17 | This tokenizer utilizes a combination of encoder and decoder Transformer blocks 18 | along with the Binary Spherical Quantization (BSQuantizer) to compress and decompress input data. 19 | 20 | Args: 21 | d_in (int): Input dimension. 22 | d_model (int): Model dimension. 23 | n_heads (int): Number of attention heads. 24 | ff_dim (int): Feed-forward dimension. 25 | n_enc_layers (int): Number of encoder layers. 26 | n_dec_layers (int): Number of decoder layers. 27 | ffn_dropout_p (float): Dropout probability for feed-forward networks. 28 | attn_dropout_p (float): Dropout probability for attention mechanisms. 29 | resid_dropout_p (float): Dropout probability for residual connections. 30 | s1_bits (int): Number of bits for the pre token in BSQuantizer. 31 | s2_bits (int): Number of bits for the post token in BSQuantizer. 32 | beta (float): Beta parameter for BSQuantizer. 33 | gamma0 (float): Gamma0 parameter for BSQuantizer. 34 | gamma (float): Gamma parameter for BSQuantizer. 35 | zeta (float): Zeta parameter for BSQuantizer. 36 | group_size (int): Group size parameter for BSQuantizer. 
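        Example (illustrative usage; sizes follow facecat/model/Kronos-Tokenizer-base/config.json):
            tokenizer = KronosTokenizer(d_in=6, d_model=256, n_heads=4, ff_dim=512,
                                        n_enc_layers=4, n_dec_layers=4, ffn_dropout_p=0.0,
                                        attn_dropout_p=0.0, resid_dropout_p=0.0,
                                        s1_bits=10, s2_bits=10, beta=0.05, gamma0=1.0,
                                        gamma=1.1, zeta=0.05, group_size=4)
            (z_pre, z_rec), bsq_loss, quantized, z_indices = tokenizer(torch.randn(2, 64, 6))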
37 | 38 | """ 39 | 40 | def __init__(self, d_in, d_model, n_heads, ff_dim, n_enc_layers, n_dec_layers, ffn_dropout_p, attn_dropout_p, resid_dropout_p, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size): 41 | 42 | super().__init__() 43 | self.d_in = d_in 44 | self.d_model = d_model 45 | self.n_heads = n_heads 46 | self.ff_dim = ff_dim 47 | self.enc_layers = n_enc_layers 48 | self.dec_layers = n_dec_layers 49 | self.ffn_dropout_p = ffn_dropout_p 50 | self.attn_dropout_p = attn_dropout_p 51 | self.resid_dropout_p = resid_dropout_p 52 | 53 | self.s1_bits = s1_bits 54 | self.s2_bits = s2_bits 55 | self.codebook_dim = s1_bits + s2_bits # Total dimension of the codebook after quantization 56 | self.embed = nn.Linear(self.d_in, self.d_model) 57 | self.head = nn.Linear(self.d_model, self.d_in) 58 | 59 | # Encoder Transformer Blocks 60 | self.encoder = nn.ModuleList([ 61 | TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) 62 | for _ in range(self.enc_layers - 1) 63 | ]) 64 | # Decoder Transformer Blocks 65 | self.decoder = nn.ModuleList([ 66 | TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) 67 | for _ in range(self.dec_layers - 1) 68 | ]) 69 | self.quant_embed = nn.Linear(in_features=self.d_model, out_features=self.codebook_dim) # Linear layer before quantization 70 | self.post_quant_embed_pre = nn.Linear(in_features=self.s1_bits, out_features=self.d_model) # Linear layer after quantization (pre part - s1 bits) 71 | self.post_quant_embed = nn.Linear(in_features=self.codebook_dim, out_features=self.d_model) # Linear layer after quantization (full codebook) 72 | self.tokenizer = BSQuantizer(self.s1_bits, self.s2_bits, beta, gamma0, gamma, zeta, group_size) # BSQuantizer module 73 | 74 | def forward(self, x): 75 | """ 76 | Forward pass of the KronosTokenizer. 77 | 78 | Args: 79 | x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in). 80 | 81 | Returns: 82 | tuple: A tuple containing: 83 | - tuple: (z_pre, z) - Reconstructed outputs from decoder with s1_bits and full codebook respectively, 84 | both of shape (batch_size, seq_len, d_in). 85 | - torch.Tensor: bsq_loss - Loss from the BSQuantizer. 86 | - torch.Tensor: quantized - Quantized representation from BSQuantizer. 87 | - torch.Tensor: z_indices - Indices from the BSQuantizer. 88 | """ 89 | z = self.embed(x) 90 | 91 | for layer in self.encoder: 92 | z = layer(z) 93 | 94 | z = self.quant_embed(z) # (B, T, codebook) 95 | 96 | bsq_loss, quantized, z_indices = self.tokenizer(z) 97 | 98 | quantized_pre = quantized[:, :, :self.s1_bits] # Extract the first part of quantized representation (s1_bits) 99 | z_pre = self.post_quant_embed_pre(quantized_pre) 100 | 101 | z = self.post_quant_embed(quantized) 102 | 103 | # Decoder layers (for pre part - s1 bits) 104 | for layer in self.decoder: 105 | z_pre = layer(z_pre) 106 | z_pre = self.head(z_pre) 107 | 108 | # Decoder layers (for full codebook) 109 | for layer in self.decoder: 110 | z = layer(z) 111 | z = self.head(z) 112 | 113 | return (z_pre, z), bsq_loss, quantized, z_indices 114 | 115 | def indices_to_bits(self, x, half=False): 116 | """ 117 | Converts indices to bit representations and scales them. 118 | 119 | Args: 120 | x (torch.Tensor): Indices tensor. 121 | half (bool, optional): Whether to process only half of the codebook dimension. Defaults to False. 122 | 123 | Returns: 124 | torch.Tensor: Bit representation tensor. 
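        Example (worked by hand, for illustration): with codebook_dim = 4 the mask is
            [1, 2, 4, 8], so index 5 -> bits (1, 0, 1, 0) -> bipolar (+1, -1, +1, -1),
            scaled by 1 / sqrt(4) = 0.5.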
125 | """ 126 | if half: 127 | x1 = x[0] # Assuming x is a tuple of indices if half is True 128 | x2 = x[1] 129 | mask = 2 ** torch.arange(self.codebook_dim//2, device=x1.device, dtype=torch.long) # Create a mask for bit extraction 130 | x1 = (x1.unsqueeze(-1) & mask) != 0 # Extract bits for the first half 131 | x2 = (x2.unsqueeze(-1) & mask) != 0 # Extract bits for the second half 132 | x = torch.cat([x1, x2], dim=-1) # Concatenate the bit representations 133 | else: 134 | mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long) # Create a mask for bit extraction 135 | x = (x.unsqueeze(-1) & mask) != 0 # Extract bits 136 | 137 | x = x.float() * 2 - 1 # Convert boolean to bipolar (-1, 1) 138 | q_scale = 1. / (self.codebook_dim ** 0.5) # Scaling factor 139 | x = x * q_scale 140 | return x 141 | 142 | def encode(self, x, half=False): 143 | """ 144 | Encodes the input data into quantized indices. 145 | 146 | Args: 147 | x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in). 148 | half (bool, optional): Whether to use half quantization in BSQuantizer. Defaults to False. 149 | 150 | Returns: 151 | torch.Tensor: Quantized indices from BSQuantizer. 152 | """ 153 | z = self.embed(x) 154 | for layer in self.encoder: 155 | z = layer(z) 156 | z = self.quant_embed(z) 157 | 158 | bsq_loss, quantized, z_indices = self.tokenizer(z, half) 159 | return z_indices 160 | 161 | def decode(self, x, half=False): 162 | """ 163 | Decodes quantized indices back to the input data space. 164 | 165 | Args: 166 | x (torch.Tensor): Quantized indices tensor. 167 | half (bool, optional): Whether the indices were generated with half quantization. Defaults to False. 168 | 169 | Returns: 170 | torch.Tensor: Reconstructed output tensor of shape (batch_size, seq_len, d_in). 171 | """ 172 | quantized = self.indices_to_bits(x, half) 173 | z = self.post_quant_embed(quantized) 174 | for layer in self.decoder: 175 | z = layer(z) 176 | z = self.head(z) 177 | return z 178 | 179 | 180 | class Kronos(nn.Module, PyTorchModelHubMixin): 181 | """ 182 | Kronos Model. 183 | 184 | Args: 185 | s1_bits (int): Number of bits for pre tokens. 186 | s2_bits (int): Number of bits for post tokens. 187 | n_layers (int): Number of Transformer blocks. 188 | d_model (int): Dimension of the model's embeddings and hidden states. 189 | n_heads (int): Number of attention heads in the MultiheadAttention layers. 190 | ff_dim (int): Dimension of the feedforward network in the Transformer blocks. 191 | ffn_dropout_p (float): Dropout probability for the feedforward network. 192 | attn_dropout_p (float): Dropout probability for the attention layers. 193 | resid_dropout_p (float): Dropout probability for residual connections. 194 | token_dropout_p (float): Dropout probability for token embeddings. 195 | learn_te (bool): Whether to use learnable temporal embeddings. 
196 | """ 197 | 198 | def __init__(self, s1_bits, s2_bits, n_layers, d_model, n_heads, ff_dim, ffn_dropout_p, attn_dropout_p, resid_dropout_p, token_dropout_p, learn_te): 199 | super().__init__() 200 | self.s1_bits = s1_bits 201 | self.s2_bits = s2_bits 202 | self.n_layers = n_layers 203 | self.d_model = d_model 204 | self.n_heads = n_heads 205 | self.learn_te = learn_te 206 | self.ff_dim = ff_dim 207 | self.ffn_dropout_p = ffn_dropout_p 208 | self.attn_dropout_p = attn_dropout_p 209 | self.resid_dropout_p = resid_dropout_p 210 | self.token_dropout_p = token_dropout_p 211 | 212 | self.s1_vocab_size = 2 ** self.s1_bits 213 | self.token_drop = nn.Dropout(self.token_dropout_p) 214 | self.embedding = HierarchicalEmbedding(self.s1_bits, self.s2_bits, self.d_model) 215 | self.time_emb = TemporalEmbedding(self.d_model, self.learn_te) 216 | self.transformer = nn.ModuleList([ 217 | TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) 218 | for _ in range(self.n_layers) 219 | ]) 220 | self.norm = RMSNorm(self.d_model) 221 | self.dep_layer = DependencyAwareLayer(self.d_model) 222 | self.head = DualHead(self.s1_bits, self.s2_bits, self.d_model) 223 | self.apply(self._init_weights) 224 | 225 | def _init_weights(self, module): 226 | 227 | if isinstance(module, nn.Linear): 228 | nn.init.xavier_normal_(module.weight) 229 | if module.bias is not None: 230 | nn.init.zeros_(module.bias) 231 | elif isinstance(module, nn.Embedding): 232 | nn.init.normal_(module.weight, mean=0, std=self.embedding.d_model ** -0.5) 233 | elif isinstance(module, nn.LayerNorm): 234 | nn.init.ones_(module.weight) 235 | nn.init.zeros_(module.bias) 236 | elif isinstance(module, RMSNorm): 237 | nn.init.ones_(module.weight) 238 | 239 | def forward(self, s1_ids, s2_ids, stamp=None, padding_mask=None, use_teacher_forcing=False, s1_targets=None): 240 | """ 241 | Args: 242 | s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len] 243 | s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len] 244 | stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None. 245 | padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None. 246 | use_teacher_forcing (bool, optional): Whether to use teacher forcing for s1 decoding. Defaults to False. 247 | s1_targets (torch.Tensor, optional): Target s1 token IDs for teacher forcing. Shape: [batch_size, seq_len]. Defaults to None. 248 | 249 | Returns: 250 | Tuple[torch.Tensor, torch.Tensor]: 251 | - s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size] 252 | - s2_logits: Logits for s2 token predictions, conditioned on s1. 
Shape: [batch_size, seq_len, s2_vocab_size]
253 |         """
254 |         x = self.embedding([s1_ids, s2_ids])
255 |         if stamp is not None:
256 |             time_embedding = self.time_emb(stamp)
257 |             x = x + time_embedding
258 |         x = self.token_drop(x)
259 | 
260 |         for layer in self.transformer:
261 |             x = layer(x, key_padding_mask=padding_mask)
262 | 
263 |         x = self.norm(x)
264 | 
265 |         s1_logits = self.head(x)
266 | 
267 |         if use_teacher_forcing:
268 |             sibling_embed = self.embedding.emb_s1(s1_targets)
269 |         else:
270 |             s1_probs = F.softmax(s1_logits.detach(), dim=-1)
271 |             sample_s1_ids = torch.multinomial(s1_probs.view(-1, self.s1_vocab_size), 1).view(s1_ids.shape)
272 |             sibling_embed = self.embedding.emb_s1(sample_s1_ids)
273 | 
274 |         x2 = self.dep_layer(x, sibling_embed, key_padding_mask=padding_mask)  # Dependency Aware Layer: Condition on s1 embeddings
275 |         s2_logits = self.head.cond_forward(x2)
276 |         return s1_logits, s2_logits
277 | 
278 |     def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None):
279 |         """
280 |         Decodes only the s1 tokens.
281 | 
282 |         This method performs a forward pass to predict only s1 tokens. It returns the s1 logits
283 |         and the context representation from the Transformer, which can be used for subsequent s2 decoding.
284 | 
285 |         Args:
286 |             s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
287 |             s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len]
288 |             stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None.
289 |             padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
290 | 
291 |         Returns:
292 |             Tuple[torch.Tensor, torch.Tensor]:
293 |                 - s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size]
294 |                 - context: Context representation from the Transformer. Shape: [batch_size, seq_len, d_model]
295 |         """
296 |         x = self.embedding([s1_ids, s2_ids])
297 |         if stamp is not None:
298 |             time_embedding = self.time_emb(stamp)
299 |             x = x + time_embedding
300 |         x = self.token_drop(x)
301 | 
302 |         for layer in self.transformer:
303 |             x = layer(x, key_padding_mask=padding_mask)
304 | 
305 |         x = self.norm(x)
306 | 
307 |         s1_logits = self.head(x)
308 |         return s1_logits, x
309 | 
310 |     def decode_s2(self, context, s1_ids, padding_mask=None):
311 |         """
312 |         Decodes the s2 tokens, conditioned on the context and s1 tokens.
313 | 
314 |         This method decodes s2 tokens based on a pre-computed context representation (typically from `decode_s1`)
315 |         and the s1 token IDs. It uses the dependency-aware layer and the conditional s2 head to predict s2 tokens.
316 | 
317 |         Args:
318 |             context (torch.Tensor): Context representation from the transformer (output of decode_s1).
319 |                 Shape: [batch_size, seq_len, d_model]
320 |             s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
321 |             padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
322 | 
323 |         Returns:
324 |             torch.Tensor: s2 logits. Shape: [batch_size, seq_len, s2_vocab_size]
325 |         """
326 |         sibling_embed = self.embedding.emb_s1(s1_ids)
327 |         x2 = self.dep_layer(context, sibling_embed, key_padding_mask=padding_mask)
328 |         return self.head.cond_forward(x2)
329 | 
330 | 
331 | def top_k_top_p_filtering(
332 |     logits,
333 |     top_k: int = 0,
334 |     top_p: float = 1.0,
335 |     filter_value: float = -float("Inf"),
336 |     min_tokens_to_keep: int = 1,
337 | ):
338 |     """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
339 |     Args:
340 |         logits: logits distribution shape (batch size, vocabulary size)
341 |         if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
342 |         if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
343 |             Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
344 |         Make sure we keep at least min_tokens_to_keep per batch example in the output
345 |     From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
346 |     """
347 |     if top_k > 0:
348 |         top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
349 |         # Remove all tokens with a probability less than the last token of the top-k
350 |         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
351 |         logits[indices_to_remove] = filter_value
352 |         # fix: no early return here, so top-k and top-p can be combined as the docstring promises
353 | 
354 |     if top_p < 1.0:
355 |         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
356 |         cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
357 | 
358 |         # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
359 |         sorted_indices_to_remove = cumulative_probs > top_p
360 |         if min_tokens_to_keep > 1:
361 |             # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
362 |             sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
363 |         # Shift the indices to the right to keep also the first token above the threshold
364 |         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
365 |         sorted_indices_to_remove[..., 0] = 0
366 | 
367 |         # scatter sorted tensors to original indexing
368 |         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
369 |         logits[indices_to_remove] = filter_value
370 |     return logits
371 | 
372 | 
373 | def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_logits=True):
374 |     logits = logits / temperature
375 |     if top_k is not None or top_p is not None:
376 |         if (top_k or 0) > 0 or (top_p if top_p is not None else 1.0) < 1.0:  # fix: guard None values before comparing
377 |             logits = top_k_top_p_filtering(logits, top_k=top_k or 0, top_p=top_p if top_p is not None else 1.0)
378 | 
379 |     probs = F.softmax(logits, dim=-1)
380 | 
381 |     if not sample_logits:
382 |         _, x = torch.topk(probs, k=1, dim=-1)  # fix: greedy pick via torch.topk (the original mistakenly called the integer argument `top_k`)
383 |     else:
384 |         x = torch.multinomial(probs, num_samples=1)
385 | 
386 |     return x
387 | 
388 | 
389 | def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False, progress_callback=None):
390 |     with torch.no_grad():
391 |         batch_size = x.size(0)
392 |         initial_seq_len = x.size(1)
393 |         x = torch.clip(x, -clip, clip)
394 | 
395 |         device = x.device
396 |         x = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x.size(1), x.size(2)).to(device)
397 |         x_stamp = x_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x_stamp.size(1), x_stamp.size(2)).to(device)
398 |         y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), 
y_stamp.size(2)).to(device) 399 | 400 | x_token = tokenizer.encode(x, half=True) 401 | 402 | def get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, pred_step): 403 | 404 | if current_seq_len <= max_context - pred_step: 405 | return torch.cat([x_stamp, y_stamp[:, :pred_step, :]], dim=1) 406 | else: 407 | start_idx = max_context - pred_step 408 | return torch.cat([x_stamp[:, -start_idx:, :], y_stamp[:, :pred_step, :]], dim=1) 409 | 410 | if verbose: 411 | ran = trange 412 | else: 413 | ran = range 414 | for i in ran(pred_len): 415 | # --- CALLBACK INVOCATION --- 416 | if progress_callback: 417 | progress_callback(i + 1, pred_len) 418 | # ------------------------- 419 | 420 | current_seq_len = initial_seq_len + i 421 | 422 | if current_seq_len <= max_context: 423 | input_tokens = x_token 424 | else: 425 | input_tokens = [t[:, -max_context:].contiguous() for t in x_token] 426 | 427 | current_stamp = get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, i) 428 | 429 | s1_logits, context = model.decode_s1(input_tokens[0], input_tokens[1], current_stamp) 430 | s1_logits = s1_logits[:, -1, :] 431 | sample_pre = sample_from_logits(s1_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True) 432 | 433 | s2_logits = model.decode_s2(context, sample_pre) 434 | s2_logits = s2_logits[:, -1, :] 435 | sample_post = sample_from_logits(s2_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True) 436 | 437 | x_token[0] = torch.cat([x_token[0], sample_pre], dim=1) 438 | x_token[1] = torch.cat([x_token[1], sample_post], dim=1) 439 | 440 | input_tokens = [t[:, -max_context:].contiguous() for t in x_token] 441 | z = tokenizer.decode(input_tokens, half=True) 442 | z = z.reshape(batch_size, sample_count, z.size(1), z.size(2)) 443 | preds = z.cpu().numpy() 444 | preds = np.mean(preds, axis=1) 445 | 446 | return preds 447 | 448 | 449 | def calc_time_stamps(x_timestamp): 450 | time_df = pd.DataFrame() 451 | time_df['minute'] = x_timestamp.dt.minute 452 | time_df['hour'] = x_timestamp.dt.hour 453 | time_df['weekday'] = x_timestamp.dt.weekday 454 | time_df['day'] = x_timestamp.dt.day 455 | time_df['month'] = x_timestamp.dt.month 456 | return time_df 457 | 458 | 459 | class KronosPredictor: 460 | 461 | def __init__(self, model, tokenizer, device="cuda:0", max_context=512, clip=5): 462 | self.tokenizer = tokenizer 463 | self.model = model 464 | self.max_context = max_context 465 | self.clip = clip 466 | self.price_cols = ['open', 'high', 'low', 'close'] 467 | self.vol_col = 'volume' 468 | self.amt_vol = 'amount' 469 | self.time_cols = ['minute', 'hour', 'weekday', 'day', 'month'] 470 | self.device = device 471 | 472 | self.tokenizer = self.tokenizer.to(self.device) 473 | self.model = self.model.to(self.device) 474 | 475 | def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, progress_callback=None): 476 | 477 | x_tensor = torch.from_numpy(np.array(x).astype(np.float32)).to(self.device) 478 | x_stamp_tensor = torch.from_numpy(np.array(x_stamp).astype(np.float32)).to(self.device) 479 | y_stamp_tensor = torch.from_numpy(np.array(y_stamp).astype(np.float32)).to(self.device) 480 | 481 | preds = auto_regressive_inference(self.tokenizer, self.model, x_tensor, x_stamp_tensor, y_stamp_tensor, self.max_context, pred_len, 482 | self.clip, T, top_k, top_p, sample_count, verbose, progress_callback) 483 | preds = preds[:, -pred_len:, :] 484 | return preds 485 | 486 | def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, 
verbose=True, progress_callback=None): 487 | 488 | if not isinstance(df, pd.DataFrame): 489 | raise ValueError("Input must be a pandas DataFrame.") 490 | 491 | if not all(col in df.columns for col in self.price_cols): 492 | raise ValueError(f"Price columns {self.price_cols} not found in DataFrame.") 493 | 494 | df = df.copy() 495 | if self.vol_col not in df.columns: 496 | df[self.vol_col] = 0.0 # Fill missing volume with zeros 497 | df[self.amt_vol] = 0.0 # Fill missing amount with zeros 498 | if self.amt_vol not in df.columns and self.vol_col in df.columns: 499 | df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1) 500 | 501 | if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any(): 502 | raise ValueError("Input DataFrame contains NaN values in price or volume columns.") 503 | 504 | x_time_df = calc_time_stamps(x_timestamp) 505 | y_time_df = calc_time_stamps(y_timestamp) 506 | 507 | x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32) 508 | x_stamp = x_time_df.values.astype(np.float32) 509 | y_stamp = y_time_df.values.astype(np.float32) 510 | 511 | x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0) 512 | 513 | x = (x - x_mean) / (x_std + 1e-5) 514 | x = np.clip(x, -self.clip, self.clip) 515 | 516 | x = x[np.newaxis, :] 517 | x_stamp = x_stamp[np.newaxis, :] 518 | y_stamp = y_stamp[np.newaxis, :] 519 | 520 | preds = self.generate(x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, progress_callback) 521 | 522 | preds = preds.squeeze(0) 523 | preds = preds * (x_std + 1e-5) + x_mean 524 | 525 | pred_df = pd.DataFrame(preds, columns=self.price_cols + [self.vol_col, self.amt_vol], index=y_timestamp) 526 | return pred_df 527 | -------------------------------------------------------------------------------- /facecat/model/kronos.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | from huggingface_hub import PyTorchModelHubMixin 5 | import sys 6 | 7 | from tqdm import trange 8 | 9 | sys.path.append("../") 10 | from model.module import * 11 | 12 | 13 | class KronosTokenizer(nn.Module, PyTorchModelHubMixin): 14 | """ 15 | KronosTokenizer module for tokenizing input data using a hybrid quantization approach. 16 | 17 | This tokenizer utilizes a combination of encoder and decoder Transformer blocks 18 | along with the Binary Spherical Quantization (BSQuantizer) to compress and decompress input data. 19 | 20 | Args: 21 | d_in (int): Input dimension. 22 | d_model (int): Model dimension. 23 | n_heads (int): Number of attention heads. 24 | ff_dim (int): Feed-forward dimension. 25 | n_enc_layers (int): Number of encoder layers. 26 | n_dec_layers (int): Number of decoder layers. 27 | ffn_dropout_p (float): Dropout probability for feed-forward networks. 28 | attn_dropout_p (float): Dropout probability for attention mechanisms. 29 | resid_dropout_p (float): Dropout probability for residual connections. 30 | s1_bits (int): Number of bits for the pre token in BSQuantizer. 31 | s2_bits (int): Number of bits for the post token in BSQuantizer. 32 | beta (float): Beta parameter for BSQuantizer. 33 | gamma0 (float): Gamma0 parameter for BSQuantizer. 34 | gamma (float): Gamma parameter for BSQuantizer. 35 | zeta (float): Zeta parameter for BSQuantizer. 36 | group_size (int): Group size parameter for BSQuantizer. 
37 | 38 | """ 39 | 40 | def __init__(self, d_in, d_model, n_heads, ff_dim, n_enc_layers, n_dec_layers, ffn_dropout_p, attn_dropout_p, resid_dropout_p, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size): 41 | 42 | super().__init__() 43 | self.d_in = d_in 44 | self.d_model = d_model 45 | self.n_heads = n_heads 46 | self.ff_dim = ff_dim 47 | self.enc_layers = n_enc_layers 48 | self.dec_layers = n_dec_layers 49 | self.ffn_dropout_p = ffn_dropout_p 50 | self.attn_dropout_p = attn_dropout_p 51 | self.resid_dropout_p = resid_dropout_p 52 | 53 | self.s1_bits = s1_bits 54 | self.s2_bits = s2_bits 55 | self.codebook_dim = s1_bits + s2_bits # Total dimension of the codebook after quantization 56 | self.embed = nn.Linear(self.d_in, self.d_model) 57 | self.head = nn.Linear(self.d_model, self.d_in) 58 | 59 | # Encoder Transformer Blocks 60 | self.encoder = nn.ModuleList([ 61 | TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) 62 | for _ in range(self.enc_layers - 1) 63 | ]) 64 | # Decoder Transformer Blocks 65 | self.decoder = nn.ModuleList([ 66 | TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) 67 | for _ in range(self.dec_layers - 1) 68 | ]) 69 | self.quant_embed = nn.Linear(in_features=self.d_model, out_features=self.codebook_dim) # Linear layer before quantization 70 | self.post_quant_embed_pre = nn.Linear(in_features=self.s1_bits, out_features=self.d_model) # Linear layer after quantization (pre part - s1 bits) 71 | self.post_quant_embed = nn.Linear(in_features=self.codebook_dim, out_features=self.d_model) # Linear layer after quantization (full codebook) 72 | self.tokenizer = BSQuantizer(self.s1_bits, self.s2_bits, beta, gamma0, gamma, zeta, group_size) # BSQuantizer module 73 | 74 | def forward(self, x): 75 | """ 76 | Forward pass of the KronosTokenizer. 77 | 78 | Args: 79 | x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in). 80 | 81 | Returns: 82 | tuple: A tuple containing: 83 | - tuple: (z_pre, z) - Reconstructed outputs from decoder with s1_bits and full codebook respectively, 84 | both of shape (batch_size, seq_len, d_in). 85 | - torch.Tensor: bsq_loss - Loss from the BSQuantizer. 86 | - torch.Tensor: quantized - Quantized representation from BSQuantizer. 87 | - torch.Tensor: z_indices - Indices from the BSQuantizer. 88 | """ 89 | z = self.embed(x) 90 | 91 | for layer in self.encoder: 92 | z = layer(z) 93 | 94 | z = self.quant_embed(z) # (B, T, codebook) 95 | 96 | bsq_loss, quantized, z_indices = self.tokenizer(z) 97 | 98 | quantized_pre = quantized[:, :, :self.s1_bits] # Extract the first part of quantized representation (s1_bits) 99 | z_pre = self.post_quant_embed_pre(quantized_pre) 100 | 101 | z = self.post_quant_embed(quantized) 102 | 103 | # Decoder layers (for pre part - s1 bits) 104 | for layer in self.decoder: 105 | z_pre = layer(z_pre) 106 | z_pre = self.head(z_pre) 107 | 108 | # Decoder layers (for full codebook) 109 | for layer in self.decoder: 110 | z = layer(z) 111 | z = self.head(z) 112 | 113 | return (z_pre, z), bsq_loss, quantized, z_indices 114 | 115 | def indices_to_bits(self, x, half=False): 116 | """ 117 | Converts indices to bit representations and scales them. 118 | 119 | Args: 120 | x (torch.Tensor): Indices tensor. 121 | half (bool, optional): Whether to process only half of the codebook dimension. Defaults to False. 122 | 123 | Returns: 124 | torch.Tensor: Bit representation tensor. 
125 | """ 126 | if half: 127 | x1 = x[0] # Assuming x is a tuple of indices if half is True 128 | x2 = x[1] 129 | mask = 2 ** torch.arange(self.codebook_dim//2, device=x1.device, dtype=torch.long) # Create a mask for bit extraction 130 | x1 = (x1.unsqueeze(-1) & mask) != 0 # Extract bits for the first half 131 | x2 = (x2.unsqueeze(-1) & mask) != 0 # Extract bits for the second half 132 | x = torch.cat([x1, x2], dim=-1) # Concatenate the bit representations 133 | else: 134 | mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long) # Create a mask for bit extraction 135 | x = (x.unsqueeze(-1) & mask) != 0 # Extract bits 136 | 137 | x = x.float() * 2 - 1 # Convert boolean to bipolar (-1, 1) 138 | q_scale = 1. / (self.codebook_dim ** 0.5) # Scaling factor 139 | x = x * q_scale 140 | return x 141 | 142 | def encode(self, x, half=False): 143 | """ 144 | Encodes the input data into quantized indices. 145 | 146 | Args: 147 | x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in). 148 | half (bool, optional): Whether to use half quantization in BSQuantizer. Defaults to False. 149 | 150 | Returns: 151 | torch.Tensor: Quantized indices from BSQuantizer. 152 | """ 153 | z = self.embed(x) 154 | for layer in self.encoder: 155 | z = layer(z) 156 | z = self.quant_embed(z) 157 | 158 | bsq_loss, quantized, z_indices = self.tokenizer(z, half) 159 | return z_indices 160 | 161 | def decode(self, x, half=False): 162 | """ 163 | Decodes quantized indices back to the input data space. 164 | 165 | Args: 166 | x (torch.Tensor): Quantized indices tensor. 167 | half (bool, optional): Whether the indices were generated with half quantization. Defaults to False. 168 | 169 | Returns: 170 | torch.Tensor: Reconstructed output tensor of shape (batch_size, seq_len, d_in). 171 | """ 172 | quantized = self.indices_to_bits(x, half) 173 | z = self.post_quant_embed(quantized) 174 | for layer in self.decoder: 175 | z = layer(z) 176 | z = self.head(z) 177 | return z 178 | 179 | 180 | class Kronos(nn.Module, PyTorchModelHubMixin): 181 | """ 182 | Kronos Model. 183 | 184 | Args: 185 | s1_bits (int): Number of bits for pre tokens. 186 | s2_bits (int): Number of bits for post tokens. 187 | n_layers (int): Number of Transformer blocks. 188 | d_model (int): Dimension of the model's embeddings and hidden states. 189 | n_heads (int): Number of attention heads in the MultiheadAttention layers. 190 | ff_dim (int): Dimension of the feedforward network in the Transformer blocks. 191 | ffn_dropout_p (float): Dropout probability for the feedforward network. 192 | attn_dropout_p (float): Dropout probability for the attention layers. 193 | resid_dropout_p (float): Dropout probability for residual connections. 194 | token_dropout_p (float): Dropout probability for token embeddings. 195 | learn_te (bool): Whether to use learnable temporal embeddings. 
196 | """ 197 | 198 | def __init__(self, s1_bits, s2_bits, n_layers, d_model, n_heads, ff_dim, ffn_dropout_p, attn_dropout_p, resid_dropout_p, token_dropout_p, learn_te): 199 | super().__init__() 200 | self.s1_bits = s1_bits 201 | self.s2_bits = s2_bits 202 | self.n_layers = n_layers 203 | self.d_model = d_model 204 | self.n_heads = n_heads 205 | self.learn_te = learn_te 206 | self.ff_dim = ff_dim 207 | self.ffn_dropout_p = ffn_dropout_p 208 | self.attn_dropout_p = attn_dropout_p 209 | self.resid_dropout_p = resid_dropout_p 210 | self.token_dropout_p = token_dropout_p 211 | 212 | self.s1_vocab_size = 2 ** self.s1_bits 213 | self.token_drop = nn.Dropout(self.token_dropout_p) 214 | self.embedding = HierarchicalEmbedding(self.s1_bits, self.s2_bits, self.d_model) 215 | self.time_emb = TemporalEmbedding(self.d_model, self.learn_te) 216 | self.transformer = nn.ModuleList([ 217 | TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p) 218 | for _ in range(self.n_layers) 219 | ]) 220 | self.norm = RMSNorm(self.d_model) 221 | self.dep_layer = DependencyAwareLayer(self.d_model) 222 | self.head = DualHead(self.s1_bits, self.s2_bits, self.d_model) 223 | self.apply(self._init_weights) 224 | 225 | def _init_weights(self, module): 226 | 227 | if isinstance(module, nn.Linear): 228 | nn.init.xavier_normal_(module.weight) 229 | if module.bias is not None: 230 | nn.init.zeros_(module.bias) 231 | elif isinstance(module, nn.Embedding): 232 | nn.init.normal_(module.weight, mean=0, std=self.embedding.d_model ** -0.5) 233 | elif isinstance(module, nn.LayerNorm): 234 | nn.init.ones_(module.weight) 235 | nn.init.zeros_(module.bias) 236 | elif isinstance(module, RMSNorm): 237 | nn.init.ones_(module.weight) 238 | 239 | def forward(self, s1_ids, s2_ids, stamp=None, padding_mask=None, use_teacher_forcing=False, s1_targets=None): 240 | """ 241 | Args: 242 | s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len] 243 | s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len] 244 | stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None. 245 | padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None. 246 | use_teacher_forcing (bool, optional): Whether to use teacher forcing for s1 decoding. Defaults to False. 247 | s1_targets (torch.Tensor, optional): Target s1 token IDs for teacher forcing. Shape: [batch_size, seq_len]. Defaults to None. 248 | 249 | Returns: 250 | Tuple[torch.Tensor, torch.Tensor]: 251 | - s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size] 252 | - s2_logits: Logits for s2 token predictions, conditioned on s1. 
    def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None):
        """
        Decodes only the s1 tokens.

        This method performs a forward pass to predict only s1 tokens. It returns the s1 logits
        and the context representation from the Transformer, which can be used for subsequent s2 decoding.

        Args:
            s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
            s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len]
            stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None.
            padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - s1_logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size]
                - context: Context representation from the Transformer. Shape: [batch_size, seq_len, d_model]
        """
        x = self.embedding([s1_ids, s2_ids])
        if stamp is not None:
            time_embedding = self.time_emb(stamp)
            x = x + time_embedding
        x = self.token_drop(x)

        for layer in self.transformer:
            x = layer(x, key_padding_mask=padding_mask)

        x = self.norm(x)

        s1_logits = self.head(x)
        return s1_logits, x

    def decode_s2(self, context, s1_ids, padding_mask=None):
        """
        Decodes the s2 tokens, conditioned on the context and s1 tokens.

        This method decodes s2 tokens based on a pre-computed context representation (typically from `decode_s1`)
        and the s1 token IDs. It uses the dependency-aware layer and the conditional s2 head to predict s2 tokens.

        Args:
            context (torch.Tensor): Context representation from the Transformer (output of `decode_s1`).
                Shape: [batch_size, seq_len, d_model]
            s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
            padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.

        Returns:
            torch.Tensor: s2 logits. Shape: [batch_size, seq_len, s2_vocab_size]
        """
        sibling_embed = self.embedding.emb_s1(s1_ids)
        x2 = self.dep_layer(context, sibling_embed, key_padding_mask=padding_mask)
        return self.head.cond_forward(x2)
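

# One-step sketch of the hierarchical decoding that auto_regressive_inference
# (further below) performs in a loop. This helper is not part of the original
# module; the hyperparameters mirror the bundled Kronos-small configuration,
# and greedy argmax stands in for the usual sampling.
def _demo_two_stage_decode():
    model = Kronos(s1_bits=10, s2_bits=10, n_layers=8, d_model=512, n_heads=8,
                   ff_dim=1024, ffn_dropout_p=0.25, attn_dropout_p=0.1,
                   resid_dropout_p=0.25, token_dropout_p=0.1, learn_te=True)
    model.eval()
    s1_ids = torch.randint(0, 2 ** 10, (1, 16))
    s2_ids = torch.randint(0, 2 ** 10, (1, 16))
    with torch.no_grad():
        s1_logits, context = model.decode_s1(s1_ids, s2_ids)        # coarse tokens first
        next_s1 = s1_logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick
        s2_logits = model.decode_s2(context, next_s1)               # fine tokens, conditioned on s1
        next_s2 = s2_logits[:, -1, :].argmax(dim=-1, keepdim=True)
    return next_s1, next_s2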
def top_k_top_p_filtering(
    logits,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits distribution of shape (batch size, vocabulary size).
        top_k: if > 0, keep only the top k tokens with highest probability (top-k filtering).
        top_p: if < 1.0, keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).
        min_tokens_to_keep: make sure we keep at least this many tokens per batch example in the output.

    Adapted from: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep - 1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter sorted tensors back to the original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits


def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_logits=True):
    logits = logits / temperature
    top_k = 0 if top_k is None else top_k    # Treat None as "no top-k filtering"
    top_p = 1.0 if top_p is None else top_p  # Treat None as "no nucleus filtering"
    if top_k > 0 or top_p < 1.0:
        logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

    probs = F.softmax(logits, dim=-1)

    if not sample_logits:
        _, x = torch.topk(probs, k=1, dim=-1)  # Greedy: take the most likely token
    else:
        x = torch.multinomial(probs, num_samples=1)

    return x
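

# Toy check of the sampling utilities above. This helper is not part of the
# original module; it only illustrates the masking behaviour on a hand-built
# batch of four logits.
def _demo_filtering():
    logits = torch.tensor([[0.1, 0.2, 1.5, 2.0]])
    # With top_k=2, every logit outside the two largest is masked to -inf ...
    filtered = top_k_top_p_filtering(logits.clone(), top_k=2)
    assert torch.isinf(filtered[0, 0]) and torch.isinf(filtered[0, 1])
    # ... so sampling can only ever return index 2 or 3 here.
    idx = sample_from_logits(logits.clone(), temperature=0.7, top_k=2, top_p=0.95)
    assert idx.item() in (2, 3)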
def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len,
                              clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5,
                              verbose=False, progress_callback=None):
    with torch.no_grad():
        batch_size = x.size(0)
        initial_seq_len = x.size(1)
        x = torch.clip(x, -clip, clip)

        device = x.device
        # Replicate each series sample_count times so several trajectories are sampled in one batch
        x = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x.size(1), x.size(2)).to(device)
        x_stamp = x_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x_stamp.size(1), x_stamp.size(2)).to(device)
        y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device)

        x_token = tokenizer.encode(x, half=True)

        def get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, pred_step):
            # Build the stamp sequence that matches the (possibly truncated) token context
            if current_seq_len <= max_context - pred_step:
                return torch.cat([x_stamp, y_stamp[:, :pred_step, :]], dim=1)
            else:
                start_idx = max_context - pred_step
                return torch.cat([x_stamp[:, -start_idx:, :], y_stamp[:, :pred_step, :]], dim=1)

        if verbose:
            ran = trange
        else:
            ran = range
        for i in ran(pred_len):
            # --- Callback invocation ---
            if progress_callback:
                progress_callback(i + 1, pred_len)
            # ---------------------------

            current_seq_len = initial_seq_len + i

            # Keep at most max_context tokens of history
            if current_seq_len <= max_context:
                input_tokens = x_token
            else:
                input_tokens = [t[:, -max_context:].contiguous() for t in x_token]

            current_stamp = get_dynamic_stamp(x_stamp, y_stamp, current_seq_len, i)

            s1_logits, context = model.decode_s1(input_tokens[0], input_tokens[1], current_stamp)
            s1_logits = s1_logits[:, -1, :]
            sample_pre = sample_from_logits(s1_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)

            s2_logits = model.decode_s2(context, sample_pre)
            s2_logits = s2_logits[:, -1, :]
            sample_post = sample_from_logits(s2_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)

            x_token[0] = torch.cat([x_token[0], sample_pre], dim=1)
            x_token[1] = torch.cat([x_token[1], sample_post], dim=1)

        input_tokens = [t[:, -max_context:].contiguous() for t in x_token]
        z = tokenizer.decode(input_tokens, half=True)
        z = z.reshape(batch_size, sample_count, z.size(1), z.size(2))
        preds = z.cpu().numpy()
        preds = np.mean(preds, axis=1)  # Average the sampled trajectories

    return preds


def calc_time_stamps(x_timestamp):
    time_df = pd.DataFrame()
    time_df['minute'] = x_timestamp.dt.minute
    time_df['hour'] = x_timestamp.dt.hour
    time_df['weekday'] = x_timestamp.dt.weekday
    time_df['day'] = x_timestamp.dt.day
    time_df['month'] = x_timestamp.dt.month
    return time_df


class KronosPredictor:

    def __init__(self, model, tokenizer, device="cuda:0", max_context=512, clip=5):
        self.tokenizer = tokenizer
        self.model = model
        self.max_context = max_context
        self.clip = clip
        self.price_cols = ['open', 'high', 'low', 'close']
        self.vol_col = 'volume'
        self.amt_vol = 'amount'
        self.time_cols = ['minute', 'hour', 'weekday', 'day', 'month']
        self.device = device

        self.tokenizer = self.tokenizer.to(self.device)
        self.model = self.model.to(self.device)

    def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, progress_callback=None):
        x_tensor = torch.from_numpy(np.array(x).astype(np.float32)).to(self.device)
        x_stamp_tensor = torch.from_numpy(np.array(x_stamp).astype(np.float32)).to(self.device)
        y_stamp_tensor = torch.from_numpy(np.array(y_stamp).astype(np.float32)).to(self.device)

        preds = auto_regressive_inference(self.tokenizer, self.model, x_tensor, x_stamp_tensor, y_stamp_tensor,
                                          self.max_context, pred_len, self.clip, T, top_k, top_p,
                                          sample_count, verbose, progress_callback)
        preds = preds[:, -pred_len:, :]
        return preds

    def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1,
                verbose=True, progress_callback=None):
        if not isinstance(df, pd.DataFrame):
            raise ValueError("Input must be a pandas DataFrame.")

        if not all(col in df.columns for col in self.price_cols):
            raise ValueError(f"Price columns {self.price_cols} not found in DataFrame.")

        df = df.copy()
        if self.vol_col not in df.columns:
            df[self.vol_col] = 0.0  # Fill missing volume with zeros
            df[self.amt_vol] = 0.0  # Fill missing amount with zeros
        if self.amt_vol not in df.columns and self.vol_col in df.columns:
            df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1)  # Approximate amount as volume times mean price

        if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any():
            raise ValueError("Input DataFrame contains NaN values in price or volume columns.")

        x_time_df = calc_time_stamps(x_timestamp)
        y_time_df = calc_time_stamps(y_timestamp)

        x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32)
        x_stamp = x_time_df.values.astype(np.float32)
        y_stamp = y_time_df.values.astype(np.float32)

        # Per-feature z-score normalization over the context window
        x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
        x = (x - x_mean) / (x_std + 1e-5)
        x = np.clip(x, -self.clip, self.clip)

        x = x[np.newaxis, :]
        x_stamp = x_stamp[np.newaxis, :]
        y_stamp = y_stamp[np.newaxis, :]

        preds = self.generate(x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, progress_callback)

        preds = preds.squeeze(0)
        preds = preds * (x_std + 1e-5) + x_mean  # De-normalize back to the original scale

        pred_df = pd.DataFrame(preds, columns=self.price_cols + [self.vol_col, self.amt_vol], index=y_timestamp)
        return pred_df
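

# End-to-end usage sketch for KronosPredictor. This helper is not part of the
# original module; the Hugging Face repo ids are placeholders (substitute your
# own checkpoints) and the OHLC data is synthetic.
def _demo_predict():
    tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
    model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
    predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)

    n_ctx, n_pred = 128, 24
    ts = pd.date_range("2024-01-02 09:30", periods=n_ctx + n_pred, freq="5min")
    close = 10.0 + np.cumsum(np.random.randn(n_ctx)) * 0.01
    df = pd.DataFrame({
        "open": close, "high": close + 0.05, "low": close - 0.05, "close": close,
    })  # volume/amount are filled in by predict() when absent

    pred_df = predictor.predict(df,
                                x_timestamp=pd.Series(ts[:n_ctx]),
                                y_timestamp=pd.Series(ts[n_ctx:]),
                                pred_len=n_pred, T=1.0, top_k=0, top_p=0.9,
                                sample_count=1, verbose=False)
    return pred_df  # columns: open/high/low/close/volume/amount, indexed by y_timestamp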
--------------------------------------------------------------------------------