├── models
│   ├── __init__.py
│   ├── leftover_llama_hypers.py
│   ├── llms.py
│   ├── gaussian_process.py
│   ├── llama.py
│   ├── validation_likelihood_tuning.py
│   ├── utils.py
│   ├── gpt.py
│   ├── darts.py
│   ├── llmtime.py
│   └── promptcast.py
├── Images
│   └── Workflow.png
├── data1
│   ├── synthetic.py
│   ├── last_val_mae.csv
│   ├── last_value_results.csv
│   ├── paper_mae_raw.csv
│   ├── paper_mae.csv
│   ├── small_context.py
│   ├── monash.py
│   ├── paper_mae_normalized.csv
│   ├── metrics.py
│   ├── serialize.py
│   └── autoformer_dataset.py
├── Experiment_results
│   ├── figures_tmp
│   │   ├── Sample_multi.pdf
│   │   ├── TurkeyPowerLLMTime GPT-41.pdf
│   │   ├── TurkeyPowerLLMTime GPT-42.pdf
│   │   ├── TurkeyPowerLLMTime GPT-3.51.pdf
│   │   ├── TurkeyPowerLLMTime GPT-3.52.pdf
│   │   ├── TurkeyPowerLinear_regression1.pdf
│   │   ├── TurkeyPowerLinear_regression2.pdf
│   │   ├── WineDataset_gemini-1.0-pro_prediction.pdf
│   │   ├── synthesized_dataset_GPT-4_TurkeyPower.pdf
│   │   ├── Counterfactual Analysis_0407_WineDataset.png
│   │   ├── Counterfactual Analysis_0407_IstanbulTraffic.png
│   │   ├── Counterfactual Analysis_0407_AirPassengersDataset.png
│   │   ├── Length_of_Training_Set Analysis_0407_WineDataset.png
│   │   ├── WineDatasetgemini-1.0-pro_prediction_meticulous.pdf
│   │   ├── Length_of_Training_Set Analysis_0407_AusBeerDataset.png
│   │   ├── WineDataset_gemini-1.0-pro_prediction_with_blanket.pdf
│   │   ├── Length_of_Training_Set Analysis_0407_IstanbulTraffic.png
│   │   ├── Length_of_Training_Set Analysis_0407_MonthlyMilkDataset.png
│   │   ├── Length_of_Training_Set Analysis_0407_AirPassengersDataset.png
│   │   └── WineDatasetgemini-1.0-pro_prediction_meticulous_with_blanket.pdf
│   └── Correlation_matrix
│       ├── correlation_matrix.csv
│       └── correlation_matrix_syn.csv
├── requirements.txt
├── config.json
├── Time_llm_results.csv
├── strength_comparison_GPT-3.5.csv
├── strength_comparison_GPT-4.csv
├── README.md
├── 0_baseline_experiment_w_gemini.ipynb
├── 4_counterfactual_analysis.ipynb
├── utils_paragraph.py
├── utils_others.py
└── 6_paraphrase_and_predict.ipynb

/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/Images/Workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Images/Workflow.png
--------------------------------------------------------------------------------
/Experiment_results/figures_tmp/Sample_multi.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/Sample_multi.pdf
--------------------------------------------------------------------------------
/Experiment_results/figures_tmp/TurkeyPowerLLMTime GPT-41.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/TurkeyPowerLLMTime GPT-41.pdf
--------------------------------------------------------------------------------
/Experiment_results/figures_tmp/TurkeyPowerLLMTime GPT-42.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/TurkeyPowerLLMTime GPT-42.pdf
--------------------------------------------------------------------------------
/Experiment_results/figures_tmp/TurkeyPowerLLMTime GPT-3.51.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/TurkeyPowerLLMTime GPT-3.51.pdf
--------------------------------------------------------------------------------
/Experiment_results/figures_tmp/TurkeyPowerLLMTime GPT-3.52.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/TurkeyPowerLLMTime GPT-3.52.pdf
--------------------------------------------------------------------------------
/Experiment_results/figures_tmp/TurkeyPowerLinear_regression1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/TurkeyPowerLinear_regression1.pdf
--------------------------------------------------------------------------------
/Experiment_results/figures_tmp/TurkeyPowerLinear_regression2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/TurkeyPowerLinear_regression2.pdf
--------------------------------------------------------------------------------
/Experiment_results/figures_tmp/WineDataset_gemini-1.0-pro_prediction.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/WineDataset_gemini-1.0-pro_prediction.pdf
--------------------------------------------------------------------------------
/Experiment_results/figures_tmp/synthesized_dataset_GPT-4_TurkeyPower.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/synthesized_dataset_GPT-4_TurkeyPower.pdf
-------------------------------------------------------------------------------- /Experiment_results/figures_tmp/Counterfactual Analysis_0407_WineDataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/Counterfactual Analysis_0407_WineDataset.png -------------------------------------------------------------------------------- /Experiment_results/figures_tmp/Counterfactual Analysis_0407_IstanbulTraffic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/Counterfactual Analysis_0407_IstanbulTraffic.png -------------------------------------------------------------------------------- /Experiment_results/figures_tmp/Counterfactual Analysis_0407_AirPassengersDataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/Counterfactual Analysis_0407_AirPassengersDataset.png -------------------------------------------------------------------------------- /Experiment_results/figures_tmp/Length_of_Training_Set Analysis_0407_WineDataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/Length_of_Training_Set Analysis_0407_WineDataset.png -------------------------------------------------------------------------------- /Experiment_results/figures_tmp/WineDatasetgemini-1.0-pro_prediction_meticulous.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/WineDatasetgemini-1.0-pro_prediction_meticulous.pdf -------------------------------------------------------------------------------- /Experiment_results/figures_tmp/Length_of_Training_Set Analysis_0407_AusBeerDataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/Length_of_Training_Set Analysis_0407_AusBeerDataset.png -------------------------------------------------------------------------------- /Experiment_results/figures_tmp/WineDataset_gemini-1.0-pro_prediction_with_blanket.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/WineDataset_gemini-1.0-pro_prediction_with_blanket.pdf -------------------------------------------------------------------------------- /Experiment_results/figures_tmp/Length_of_Training_Set Analysis_0407_IstanbulTraffic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/Length_of_Training_Set Analysis_0407_IstanbulTraffic.png -------------------------------------------------------------------------------- /Experiment_results/figures_tmp/Length_of_Training_Set Analysis_0407_MonthlyMilkDataset.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/Length_of_Training_Set Analysis_0407_MonthlyMilkDataset.png -------------------------------------------------------------------------------- /Experiment_results/figures_tmp/Length_of_Training_Set Analysis_0407_AirPassengersDataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/Length_of_Training_Set Analysis_0407_AirPassengersDataset.png -------------------------------------------------------------------------------- /Experiment_results/figures_tmp/WineDatasetgemini-1.0-pro_prediction_meticulous_with_blanket.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingyuJ666/Time-Series-Forecasting-with-LLMs/HEAD/Experiment_results/figures_tmp/WineDatasetgemini-1.0-pro_prediction_meticulous_with_blanket.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | pandas 3 | numpy 4 | seaborn 5 | google-generativeai 6 | scikit-learn 7 | openai<=0.28.1 8 | darts 9 | statsmodels 10 | tiktoken 11 | tqdm 12 | gpytorch 13 | transformers 14 | datasets 15 | multiprocess 16 | SentencePiece 17 | accelerate 18 | gdown 19 | scipy -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "OPENAI_API_KEY": "[YOUR OPENAI_API_KEY]", 3 | "OPENAI_BASE_URL": "[YOUR OPENAI_BASE_URL]", 4 | "GEMINI_API_KEY": "[YOUR GEMINI_API_KEY]", 5 | "GEMINI_BASE_URL": "[YOUR GEMINI_BASE_URL]", 6 | "LLAMA_API_KEY": "[YOUR LLAMA_API_KEY]", 7 | "LLAMA_BASE_URL": "[YOUR LLAMA_BASE_URL]" 8 | } -------------------------------------------------------------------------------- /Experiment_results/Correlation_matrix/correlation_matrix.csv: -------------------------------------------------------------------------------- 1 | GPT4-MAPE,GPT4-R^2,GPT3.5-MAPE,GPT3.5-R^2,trend_strength,seas_strength 2 | 1.0,-0.18637491085890565,0.9873979984914336,-0.3573000053467488,-0.020636700627533585,-0.6814399373699493 3 | -0.18637491085890565,1.0,-0.28320639975853806,0.8286954209993082,0.5755840259600259,0.4304714811270461 4 | 0.9873979984914336,-0.28320639975853806,1.0,-0.43348317814331405,-0.11508704607490433,-0.6699829113249919 5 | -0.3573000053467488,0.8286954209993082,-0.43348317814331405,1.0,0.4883568955966373,0.5970887994611797 6 | -0.020636700627533585,0.5755840259600259,-0.11508704607490433,0.4883568955966373,1.0,0.5089804654550105 7 | -0.6814399373699493,0.4304714811270461,-0.6699829113249919,0.5970887994611797,0.5089804654550105,1.0 8 | -------------------------------------------------------------------------------- /Time_llm_results.csv: -------------------------------------------------------------------------------- 1 | Dataset Name,GPT4-MAPE,GPT4-R^2,GPT3.5-MAPE,GPT3.5-R^2,trend_strength,seas_strength 2 | AirPassengersDataset,6.8,0.79,9.98,0.32,0.997409295,0.984544471 3 | AusBeerDataset,3.69,0.78,5.12,0.57,0.987717269,0.961133656 4 | MonthlyMilkDataset,5.12,0.38,6.25,-0.34,0.996473303,0.993949709 5 | SunspotsDataset,334.3,-0.43,194.29,-1.21,0.810361724,0.275599419 6 
| WineDataset,10.9,0.49,14.98,0.11,0.672305082,0.920319441 7 | WoolyDataset,20.41,-1.74,19.26,-1.42,0.959486026,0.822665482 8 | IstanbulTrafficGPT,47.29,-1.96,60.11,-1.75,0.314842588,0.717737373 9 | GasRateCO2Dataset,4.21,-0.05,5.97,-1.47,0.64601373,0.503739374 10 | HeartRateDataset,7.9,-0.85,6.75,-0.4,0.417927681,0.488174248 11 | TurkeyPower,3.36,0.71,3.52,0.76,0.901619187,0.875574189 12 | -------------------------------------------------------------------------------- /Experiment_results/Correlation_matrix/correlation_matrix_syn.csv: -------------------------------------------------------------------------------- 1 | trend_strength,seasonal_strength,R^2_GPT_3,R^2_GPT_4,MAPE_GPT_3,MAPE_GPT_4 2 | 1.0,2.4494820473748718e-17,-0.3230931024259871,-0.2358930282830184,-0.6355181747946096,-0.6333989310826924 3 | 2.4494820473748718e-17,1.0,0.36691993940701023,0.5236949563494339,0.2847153378224841,0.2440273209622527 4 | -0.3230931024259871,0.36691993940701023,1.0,0.2619135902462044,0.20396302042517828,0.22146332230143714 5 | -0.2358930282830184,0.5236949563494339,0.2619135902462044,1.0,0.25482176144898105,0.1936105344631017 6 | -0.6355181747946096,0.2847153378224841,0.20396302042517828,0.25482176144898105,1.0,0.9893825441976956 7 | -0.6333989310826924,0.2440273209622527,0.22146332230143714,0.1936105344631017,0.9893825441976956,1.0 8 | -------------------------------------------------------------------------------- /data1/synthetic.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import numpy as np 4 | import pandas as pd 5 | 6 | def get_synthetic_datasets(): 7 | dss = [] 8 | dir_path = os.path.dirname(os.path.realpath(__file__)) 9 | dss += glob(f'{dir_path}/../datasets/synthetic/*.npy') 10 | x = [] 11 | labels = [] 12 | for ds in dss: 13 | s = np.load(ds)[:3] 14 | x.append(s) 15 | labels += [ds.split('/')[-1].split('.')[0]+(f'_{i}' if len(s)>1 else '') for i in range(len(s))] 16 | x = np.concatenate(x) 17 | # subtract mean 18 | # x -= np.mean(x, axis=1, keepdims=True) 19 | data = [pd.Series(x[i],index=pd.RangeIndex(len(x[i]))) for i in range(len(x))] 20 | synthetic_datasets = {dsname:(dat[:140],dat[140:]) for dsname,dat in zip(labels,data)} 21 | return synthetic_datasets -------------------------------------------------------------------------------- /data1/last_val_mae.csv: -------------------------------------------------------------------------------- 1 | dataset,mae 2 | bitcoin,7.777284173521224e+17 3 | wind_4_seconds,0.0 4 | tourism_quarterly,15845.100306204946 5 | traffic_weekly,1.1855844384623289 6 | us_births,1152.6666666666667 7 | sunspot,3.933333396911621 8 | covid_deaths,353.70939849624057 9 | weather,2.362190193902301 10 | nn5_daily,8.262752532958984 11 | solar_10_minutes,2.7269221451758905 12 | fred_md,2825.672461360778 13 | tourism_yearly,99456.0540551959 14 | oikolab_weather,120.03774162319314 15 | cif_2016,386526.36704240675 16 | solar_4_seconds,0.0 17 | australian_electricity_demand,659.600688770839 18 | traffic_hourly,0.026246391982575817 19 | solar_weekly,1729.4092503457175 20 | tourism_monthly,5636.83029361023 21 | nn5_weekly,16.708553516113007 22 | pedestrian_counts,170.8838383838384 23 | kaggle_web_traffic_weekly,2081.781183003247 24 | hospital,24.06573229030856 25 | saugeenday,21.496667098999023 26 | -------------------------------------------------------------------------------- /strength_comparison_GPT-3.5.csv: -------------------------------------------------------------------------------- 1 | 
avg_trend_strength,avg_seasonal_strength,median_trend_strength,median_seasonal_strength,test_trend_strength,test_seasonal_strength 2 | 0.721730304406326,0.621241818388548,0.7162911205250939,0.781146599675492,0.98864837090548,0.9919672173330345 3 | 0.8402750772794457,0.9430159359264207,0.48583470293186537,0.9746900512078235,0.8374566970833126,0.9639220059266128 4 | 0.513852224979364,0.7623892107907864,0.36701151385209874,0.6085157319125409,0.7731780722174273,0.9073980236057049 5 | 0.606435552390031,0.6636136782557012,0.7354295510381645,0.7999582827912028,0.9584117563165646,0.99416738027759 6 | 0.9064476488067872,0.41184268788880674,0.957021603706122,0.327851581811026,0.8492924666947155,0.2161210873348205 7 | 0.058792732136425926,0.499139646523569,0.031281532946074275,0.665622777026466,0.1251408208183472,0.9182823833955689 8 | 0.4617301857421003,0.4491424794023309,0.8185093499531757,0.5622837926929916,0.9236267245113854,0.9054543584768221 9 | 0.7474085228611901,0.8492087018253025,0.9700670005175012,0.7914244940234587,0.6312536346153146,0.9057394079424853 10 | 1.0,1.0,1.0,1.0,1.0,1.0 11 | 0.710287583143777,0.9172995815497764,0.8103355854175758,0.9934686981186472,0.7580377254665938,0.9702812660910854 12 | -------------------------------------------------------------------------------- /strength_comparison_GPT-4.csv: -------------------------------------------------------------------------------- 1 | avg_trend_strength,avg_seasonal_strength,median_trend_strength,median_seasonal_strength,test_trend_strength,test_seasonal_strength 2 | 0.8615517415993198,0.8151108807309505,0.9955219336160295,0.987307351356699,0.98864837090548,0.9919672173330345 3 | 0.7725482772743126,0.9796718070106726,0.5547295483996076,0.9971740239622194,0.8374566970833126,0.9639220059266128 4 | 0.765062300825345,0.8266666908278311,0.8223039396770033,0.8761374019114323,0.7731780722174273,0.9073980236057049 5 | 0.9872347403131234,0.9937830101205509,0.9975021011763137,0.9995390086709713,0.9584117563165646,0.99416738027759 6 | 0.7707842088541392,0.3400514678925044,0.692736305456758,0.4533280007802446,0.8492924666947155,0.2161210873348205 7 | 0.4067318008790828,0.7034611174257106,0.22508257819531474,0.6677880166881824,0.1251408208183472,0.9182823833955689 8 | 0.7277904253926225,0.8182285474869987,0.7023337543736485,0.8485550477623117,0.9236267245113854,0.9054543584768221 9 | 0.817088333969475,0.8487550136640204,0.9747208823846725,0.8028140344463537,0.6312536346153146,0.9057394079424853 10 | 1.0,1.0,1.0,1.0,1.0,1.0 11 | 0.8341485518411828,0.9865895458497782,0.9282146175963917,0.998652372077613,0.7580377254665938,0.9702812660910854 12 | -------------------------------------------------------------------------------- /data1/last_value_results.csv: -------------------------------------------------------------------------------- 1 | dataset,rmse,mse,mae 2 | weather,3.114059003829736,19.799946602304583,2.362190193902301 3 | tourism_yearly,111078.12328993063,363513954611.7144,99456.0540551959 4 | tourism_quarterly,19527.771486861628,6201928361.602383,15845.100306204946 5 | tourism_monthly,7374.891631391369,619111237.8884984,5636.83029361023 6 | cif_2016,513908.01521842513,4215886809397.5137,386526.3670424068 7 | australian_electricity_demand,800.8367705189698,885757.1106004094,652.185211690267 8 | pedestrian_counts,201.73714366476048,151208.0053661616,149.78377525252526 9 | nn5_weekly,20.207279217372406,551.324271498317,16.708553516113007 10 | kaggle_web_traffic_weekly,2751.056545255073,745610075.6193119,2081.781183003247 11 | 
solar_10_minutes,2.916518075950651,15.696517675162108,2.72692214517589 12 | solar_weekly,1918.441068452442,4732343.391918798,1729.4092503457173 13 | fred_md,3128.761597032152,789147915.1258634,2825.672461360779 14 | traffic_hourly,0.038692458579408555,0.0018708090134443467,0.029967592352576376 15 | traffic_weekly,1.5367087177120735,4.087448585855902,1.1855844384623289 16 | hospital,28.901001578961637,5851.362559756627,24.06573229030856 17 | covid_deaths,403.4146883707207,4435194.603884712,353.7093984962405 18 | saugeenday,39.79399009437644,1583.56164763133,21.496667098999023 19 | us_births,1608.8383179590587,2588360.7333333334,1152.6666666666667 20 | solar_4_seconds,0.0,0.0,0.0 21 | wind_4_seconds,0.0,0.0,0.0 22 | oikolab_weather,130.07108228180007,81269.08468305029,120.03774162319314 23 | -------------------------------------------------------------------------------- /models/leftover_llama_hypers.py: -------------------------------------------------------------------------------- 1 | 2 | #missing 3 | # hypers = { 4 | # "base": 10, 5 | # "prec": args.prec, 6 | # "time_sep": args.time_sep, 7 | # "bit_sep": args.bit_sep, 8 | # "missing_str": "NaN", 9 | # } 10 | 11 | 12 | 13 | 14 | 15 | # promptcast_hypers = dict( 16 | # base=10, 17 | # prec=0, 18 | # signed=True, 19 | # time_sep=',', 20 | # bit_sep='', 21 | # plus_sign='', 22 | # minus_sign='-', 23 | # half_bin_correction=False, 24 | # decimal_point='' 25 | # ) 26 | # hypers = promptcast_hypers 27 | 28 | 29 | # beta = 0 # args.beta 30 | # alpha = -1 # args.alpha 31 | # prec = 0 # args.prec 32 | 33 | # for ds_tuple in ds_tuples: 34 | # print(ds_tuple) 35 | 36 | # dsname, train_frac = ds_tuple 37 | 38 | # print(f"Running on {dsname}...") 39 | 40 | # hypers = { 41 | # "base": 10, 42 | # "prec": prec, 43 | # "time_sep": args.time_sep, 44 | # "bit_sep": args.bit_sep, 45 | # "signed": True, 46 | # } 47 | 48 | 49 | 50 | #monash 51 | # hypers = { 52 | # "base": 10, 53 | # "prec": args.prec, 54 | # "time_sep": args.time_sep, 55 | # "bit_sep": args.bit_sep, 56 | # "signed": True, 57 | # } 58 | 59 | 60 | #autoformer 61 | 62 | # dsname, series_num = ds_tuple 63 | 64 | # if dsname == "national_illness.csv": 65 | # test_length = 36 66 | 67 | # df = pd.read_csv( 68 | # f"/private/home/ngruver/time-series-lm/autoformer/{dsname}.csv" 69 | # ) 70 | 71 | # train = df.iloc[:-test_length,series_num] 72 | # test = df.iloc[-test_length:,series_num] 73 | 74 | # hypers = { 75 | # "base": 10, 76 | # "prec": args.prec, 77 | # "time_sep": args.time_sep, 78 | # "bit_sep": args.bit_sep, 79 | # "signed": True, 80 | # } 81 | 82 | 83 | # parser.add_argument("--alpha", type=float, default=0.99) 84 | # parser.add_argument("--beta", type=float, default=0.3) 85 | # parser.add_argument("--prec", type=int, default=3) -------------------------------------------------------------------------------- /data1/paper_mae_raw.csv: -------------------------------------------------------------------------------- 1 | Dataset,SES,Theta,TBATS,ETS,(DHR-)ARIMA,PR,CatBoost,FFNN,DeepAR,N-BEATS,WaveNet,Transformer 2 | Tourism Yearly,95579.23,90653.60,94121.08,94818.89,95033.24,82682.97,79567.22,79593.22,71471.29,70951.80,69905.47,74316.52 3 | Tourism Quarterly,15014.19,7656.49,9972.42,8925.52,10475.47,9092.58,10267.97,8981.04,9511.37,8640.56,9137.12,9521.67 4 | Tourism Monthly,5302.10,2069.96,2940.08,2004.51,2536.77,2187.28,2537.04,2022.21,1871.69,2003.02,2095.13,2146.98 5 | CIF 
2016,581875.97,714818.58,855578.40,642421.42,469059.49,563205.57,603551.30,1495923.44,3200418.00,679034.80,5998224.62,4057973.04 6 | Aus. Electricity Demand,659.60,665.04,370.74,1282.99,1045.92,247.18,241.77,258.76,302.41,213.83,227.50,231.45 7 | Dominick,5.70,5.86,7.08,5.81,7.10,8.19,8.09,5.85,5.23,8.28,5.10,5.18 8 | Bitcoin,5.33e18,5.33e18,9.9e17,1.10e18,3.62e18,6.66e17,1.93e18,1.45e18,1.95e18,1.06e18,2.46e18,2.61e18 9 | Pedestrian Counts,170.87,170.94,222.38,216.50,635.16,44.18,43.41,46.41,44.78,66.84,46.46,47.29 10 | Vehicle Trips,29.98,30.76,21.21,30.95,30.07,27.24,22.61,22.93,22.00,28.16,24.15,28.01 11 | KDD Cup,42.04,42.06,39.20,44.88,52.20,36.85,34.82,37.16,48.98,49.10,37.08,44.46 12 | Weather,2.24,2.51,2.30,2.35,2.45,8.17,2.51,2.09,2.02,2.34,2.29,2.03 13 | NN5 Daily,6.63,3.80,3.70,3.72,4.41,5.47,4.22,4.06,3.94,4.92,3.97,4.16 14 | NN5 Weekly,15.66,15.30,14.98,15.70,15.38,14.94,15.29,15.02,14.69,14.19,19.34,20.34 15 | Kaggle Daily,363.43,358.73,415.40,403.23,340.36,-,-,-,-,-,-, 16 | Kaggle Weekly,2337.11,2373.98,2241.84,2668.28,3115.03,4051.75,10715.36,2025.23,2272.58,2051.30,2025.50,3100.32 17 | Solar 10 Minutes,3.28,3.29,8.77,3.28,2.37,3.28,5.69,3.28,3.28,3.52,-,3.28 18 | Solar Weekly,1202.39,1210.83,908.65,1131.01,839.88,1044.98,1513.49,1050.84,721.59,1172.64,1996.89,576.35 19 | Electricity Hourly,845.97,846.03,574.30,1344.61,868.20,537.38,407.14,354.39,329.75,350.37,286.56,398.80 20 | Electricity Weekly,74149.18,74111.14,24347.24,67737.82,28457.18,44882.52,34518.43,27451.83,50312.05,32991.72,61429.32,76382.47 21 | Carparts,0.55,0.53,0.58,0.56,0.56,0.41,0.53,0.39,0.39,0.98,0.40,0.39 22 | FRED-MD,2798.22,3492.84,1989.97,2041.42,2957.11,8921.94,2475.68,2339.57,4264.36,2557.80,2508.40,4666.04 23 | Traffic Hourly,0.03,0.03,0.04,0.03,0.04,0.02,0.02,0.01,0.01,0.02,0.02,0.01 24 | Traffic Weekly,1.12,1.13,1.17,1.14,1.22,1.13,1.17,1.15,1.18,1.11,1.20,1.42 25 | Rideshare,6.29,7.62,6.45,6.29,3.37,6.30,6.07,6.59,6.28,5.55,2.75,6.29 26 | Hospital,21.76,18.54,17.43,17.97,19.60,19.24,19.17,22.86,18.25,20.18,19.35,36.19 27 | COVID Deaths,353.71,321.32,96.29,85.59,85.77,347.98,475.15,144.14,201.98,158.81,1049.48,408.66 28 | Temperature Rain,8.18,8.22,7.14,8.21,7.19,6.13,6.76,5.56,5.37,7.28,5.81,5.24 29 | Sunspot,4.93,4.93,2.57,4.93,2.57,3.83,2.27,7.97,0.77,14.47,0.17,0.13 30 | Saugeen River Flow,21.50,21.49,22.26,30.69,22.38,25.24,21.28,22.98,23.51,27.92,22.17,28.06 31 | US Births,1192.20,586.93,399.00,419.73,526.33,574.93,441.70,557.87,424.93,422.00,504.40,452.87 -------------------------------------------------------------------------------- /data1/paper_mae.csv: -------------------------------------------------------------------------------- 1 | Dataset,SES,Theta,TBATS,ETS,(DHR-)ARIMA,PR,CatBoost,FFNN,DeepAR,N-BEATS,WaveNet,Transformer,Last Value 2 | tourism_yearly,95579.23,90653.6,94121.08,94818.89,95033.24,82682.97,79567.22,79593.22,71471.29,70951.8,69905.47,74316.52,99456.0540551959 3 | tourism_quarterly,15014.19,7656.49,9972.42,8925.52,10475.47,9092.58,10267.97,8981.04,9511.37,8640.56,9137.12,9521.67,15845.100306204946 4 | tourism_monthly,5302.1,2069.96,2940.08,2004.51,2536.77,2187.28,2537.04,2022.21,1871.69,2003.02,2095.13,2146.98,5636.83029361023 5 | cif_2016,581875.97,714818.58,855578.4,642421.42,469059.49,563205.57,603551.3,1495923.44,3200418.0,679034.8,5998224.62,4057973.04,386526.3670424068 6 | australian_electricity_demand,659.6,665.04,370.74,1282.99,1045.92,247.18,241.77,258.76,302.41,213.83,227.5,231.45,659.600688770839 7 | 
dominick,5.7,5.86,7.08,5.81,7.1,8.19,8.09,5.85,5.23,8.28,5.1,5.18, 8 | bitcoin,5.33e+18,5.33e+18,9.9e+17,1.1e+18,3.62e+18,6.66e+17,1.93e+18,1.45e+18,1.95e+18,1.06e+18,2.46e+18,2.61e+18,7.777284173521224e+17 9 | pedestrian_counts,170.87,170.94,222.38,216.5,635.16,44.18,43.41,46.41,44.78,66.84,46.46,47.29,170.8838383838384 10 | vehicle_trips,29.98,30.76,21.21,30.95,30.07,27.24,22.61,22.93,22.0,28.16,24.15,28.01, 11 | kdd_cup,42.04,42.06,39.2,44.88,52.2,36.85,34.82,37.16,48.98,49.1,37.08,44.46, 12 | weather,2.24,2.51,2.3,2.35,2.45,8.17,2.51,2.09,2.02,2.34,2.29,2.03,2.362190193902301 13 | nn5_daily,6.63,3.8,3.7,3.72,4.41,5.47,4.22,4.06,3.94,4.92,3.97,4.16,8.262752532958984 14 | nn5_weekly,15.66,15.3,14.98,15.7,15.38,14.94,15.29,15.02,14.69,14.19,19.34,20.34,16.708553516113007 15 | kaggle_daily,363.43,358.73,415.4,403.23,340.36,,,,,,,, 16 | kaggle_web_traffic_weekly,2337.11,2373.98,2241.84,2668.28,3115.03,4051.75,10715.36,2025.23,2272.58,2051.3,2025.5,3100.32,2081.781183003247 17 | solar_10_minutes,3.28,3.29,8.77,3.28,2.37,3.28,5.69,3.28,3.28,3.52,,3.28,2.7269221451758905 18 | solar_weekly,1202.39,1210.83,908.65,1131.01,839.88,1044.98,1513.49,1050.84,721.59,1172.64,1996.89,576.35,1729.4092503457175 19 | electricity_hourly,845.97,846.03,574.3,1344.61,868.2,537.38,407.14,354.39,329.75,350.37,286.56,398.8, 20 | electricity_weekly,74149.18,74111.14,24347.24,67737.82,28457.18,44882.52,34518.43,27451.83,50312.05,32991.72,61429.32,76382.47, 21 | carparts,0.55,0.53,0.58,0.56,0.56,0.41,0.53,0.39,0.39,0.98,0.4,0.39, 22 | fred_md,2798.22,3492.84,1989.97,2041.42,2957.11,8921.94,2475.68,2339.57,4264.36,2557.8,2508.4,4666.04,2825.672461360778 23 | traffic_hourly,0.03,0.03,0.04,0.03,0.04,0.02,0.02,0.01,0.01,0.02,0.02,0.01,0.0262463919825758 24 | traffic_weekly,1.12,1.13,1.17,1.14,1.22,1.13,1.17,1.15,1.18,1.11,1.2,1.42,1.1855844384623289 25 | rideshare,6.29,7.62,6.45,6.29,3.37,6.3,6.07,6.59,6.28,5.55,2.75,6.29, 26 | hospital,21.76,18.54,17.43,17.97,19.6,19.24,19.17,22.86,18.25,20.18,19.35,36.19,24.06573229030856 27 | covid_deaths,353.71,321.32,96.29,85.59,85.77,347.98,475.15,144.14,201.98,158.81,1049.48,408.66,353.70939849624057 28 | temperature_rain,8.18,8.22,7.14,8.21,7.19,6.13,6.76,5.56,5.37,7.28,5.81,5.24, 29 | sunspot,4.93,4.93,2.57,4.93,2.57,3.83,2.27,7.97,0.77,14.47,0.17,0.13,3.933333396911621 30 | saugeenday,21.5,21.49,22.26,30.69,22.38,25.24,21.28,22.98,23.51,27.92,22.17,28.06,21.496667098999023 31 | us_births,1192.2,586.93,399.0,419.73,526.33,574.93,441.7,557.87,424.93,422.0,504.4,452.87,1152.6666666666667 32 | -------------------------------------------------------------------------------- /data1/small_context.py: -------------------------------------------------------------------------------- 1 | import darts.datasets 2 | import pandas as pd 3 | 4 | dataset_names = [ 5 | 'AirPassengersDataset', 6 | 'AusBeerDataset', 7 | 'AustralianTourismDataset', 8 | 'ETTh1Dataset', 9 | 'ETTh2Dataset', 10 | 'ETTm1Dataset', 11 | 'ETTm2Dataset', 12 | 'ElectricityDataset', 13 | 'EnergyDataset', 14 | 'ExchangeRateDataset', 15 | 'GasRateCO2Dataset', 16 | 'HeartRateDataset', 17 | 'ILINetDataset', 18 | 'IceCreamHeaterDataset', 19 | 'MonthlyMilkDataset', 20 | 'MonthlyMilkIncompleteDataset', 21 | 'SunspotsDataset', 22 | 'TaylorDataset', 23 | 'TemperatureDataset', 24 | 'TrafficDataset', 25 | 'USGasolineDataset', 26 | 'UberTLCDataset', 27 | 'WeatherDataset', 28 | 'WineDataset', 29 | 'WoolyDataset', 30 | ] 31 | 32 | def get_descriptions(w_references=False): 33 | descriptions = [] 34 | for dsname in dataset_names: 35 | d = 
getattr(darts.datasets,dsname)().__doc__
36 | 
37 |         if w_references:
38 |             descriptions.append(d)
39 |             continue
40 | 
41 |         lines = []
42 |         for l in d.split("\n"):
43 |             if l.strip().startswith("References"):
44 |                 break
45 |             if l.strip().startswith("Source"):
46 |                 break
47 |             if l.strip().startswith("Obtained"):
48 |                 break
49 |             lines.append(l)
50 | 
51 |         d = " ".join([x.strip() for x in lines]).strip()
52 | 
53 |         descriptions.append(d)
54 | 
55 |     return dict(zip(dataset_names,descriptions))
56 | 
57 | # def get_all_datasets(func):
58 | #     def wrapper(*args, **kwargs):
59 | #         if args[0] == 'all':
60 | #             datasets_dict = {}
61 | #             for name in dataset_names:
62 | #                 return datasets_dict.update({name: func(name)})
63 | #         else:
64 | #             return func(*args)
65 | #     return wrapper
66 | 
67 | # @get_all_datasets
68 | def get_dataset(dsname):
69 |     darts_ds = getattr(darts.datasets,dsname)().load()
70 |     if dsname=='GasRateCO2Dataset':
71 |         darts_ds = darts_ds[darts_ds.columns[1]]
72 |     series = darts_ds.pd_series()
73 | 
74 |     if dsname == 'SunspotsDataset':
75 |         series = series.iloc[::4]
76 |     if dsname =='HeartRateDataset':
77 |         series = series.iloc[::2]
78 |     return series
79 | 
80 | 
81 | 
82 | 
83 | def get_datasets(n=-1,testfrac=0.2):
84 |     datasets = [
85 |         'AirPassengersDataset',
86 |         'AusBeerDataset',
87 |         'GasRateCO2Dataset', # multivariate
88 |         'MonthlyMilkDataset',
89 |         'SunspotsDataset', # very big, need to subsample?
90 |         'WineDataset',
91 |         'WoolyDataset',
92 |         'HeartRateDataset',
93 |     ]
94 |     datas = []
95 |     for i,dsname in enumerate(datasets):
96 |         series = get_dataset(dsname)
97 |         splitpoint = int(len(series)*(1-testfrac))
98 | 
99 |         train = series.iloc[:splitpoint] # Only test the last couple of samples
100 |         test = series.iloc[splitpoint:]
101 |         datas.append((train,test))
102 |         if i+1==n:
103 |             break
104 |     return dict(zip(datasets,datas))
105 | 
106 | def get_memorization_datasets(n=-1,testfrac=0.15, predict_steps=30):
107 |     datasets = [
108 |         'IstanbulTraffic',
109 |         'TSMCStock',
110 |         'TurkeyPower'
111 |     ]
112 |     datas = []
113 |     for i,dsname in enumerate(datasets):
114 |         with open(f"datasets/memorization/{dsname}.csv") as f:
115 |             series = pd.read_csv(f, index_col=0, parse_dates=True).values.reshape(-1)
116 |         # treat as float
117 |         series = series.astype(float)
118 |         series = pd.Series(series)
119 |         if predict_steps is not None:
120 |             splitpoint = len(series)-predict_steps
121 |         else:
122 |             splitpoint = int(len(series)*(1-testfrac))
123 |         train = series.iloc[:splitpoint]
124 |         test = series.iloc[splitpoint:]
125 |         datas.append((train,test))
126 |         if i+1==n:
127 |             break
128 |     return dict(zip(datasets,datas))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | The official implementation of the paper "[Time series forecasting with llms: Understanding and enhancing model capabilities](https://dl.acm.org/doi/10.1145/3715073.3715083)", which was accepted by SIGKDD Explorations Newsletter.
2 | 
3 | # Time-Series-Forecasting-with-LLMs
4 | 
5 | Large language models (LLMs) have been applied in many fields and have developed rapidly in recent years. As a classic machine learning task, time series forecasting has recently been boosted by LLMs. Recent works treat large language models as *zero-shot* time series reasoners without further fine-tuning, achieving remarkable performance.
However, some unexplored research problems exist when applying LLMs for time series forecasting under the zero-shot setting. For instance, LLMs' preferences for the input time series are less understood. In this paper, by comparing LLMs with traditional time series forecasting models, we observe many interesting properties of LLMs in the context of time series forecasting. First, our study shows that LLMs perform well in predicting time series with clear patterns and trends but face challenges with datasets lacking periodicity. This observation can be explained by the ability of LLMs to recognize the underlying period within datasets, which is supported by our experiments. In addition, the input strategy is investigated, and it is found that incorporating external knowledge and adopting natural language paraphrases substantially improve the predictive performance of LLMs for time series. Our study contributes insights into LLMs' advantages and limitations in time series forecasting under different conditions.
6 | 
7 | ![Workflow](Images/Workflow.png)
8 | 
9 | ## Brief Information for each file:
10 | - `0_baseline_experiment.ipynb` and `0_baseline_experiment_w_gemini.ipynb`
11 | These notebooks implement baseline predictions using pre-trained large language models (LLMs), including Gemini-based variants.
12 | - `1_stl_decomposition.ipynb`
13 | This notebook applies Seasonal-Trend decomposition using LOESS (STL) to evaluate the strength of trend and seasonal components in time series data.
14 | - `2_model_preference_analysis.ipynb`
15 | This notebook analyzes the preferences of LLMs when applied to various datasets by comparing their performance on both real-world and synthetic data.
16 | - `3_Period_prediction_w_GEMINI.ipynb`
17 | This notebook performs period prediction using the Gemini model.
18 | - `4_counterfactual_analysis.ipynb`
19 | This notebook conducts counterfactual analysis to assess model behavior under hypothetical scenarios.
20 | - `5_impact_of_sequence_length.ipynb`
21 | This notebook investigates how varying input sequence lengths affect model predictions.
22 | - `6_paraphrase_and_predict.ipynb`
23 | This notebook explores a method that combines input paraphrasing with prediction to evaluate the consistency and robustness of the model.
24 | 
25 | ## Getting Started
26 | ### Installation
27 | 
28 | 1. Create a new Conda environment:
29 | 
30 | ```bash
31 | conda create -n time_llm python=3.10
32 | ```
33 | 
34 | 2. Activate the environment:
35 | 
36 | ```bash
37 | conda activate time_llm
38 | ```
39 | 
40 | 3. Install the required packages:
41 | 
42 | ```bash
43 | pip install -r requirements.txt
44 | ```
45 | 
46 | ### Configuration
47 | Add your `api_key` and `api_base` to the `config.json` file before running any notebooks.
48 | 
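A minimal sketch of reading these values (the exact loading code lives in the notebooks; the OpenAI client wiring below is an assumption based on the pinned `openai<=0.28.1`):

```python
import json

import openai  # openai<=0.28.1, as pinned in requirements.txt

with open("config.json") as f:
    config = json.load(f)

# Assumed wiring: hand the key and base URL from config.json to the client.
openai.api_key = config["OPENAI_API_KEY"]
openai.api_base = config["OPENAI_BASE_URL"]
```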
49 | ## Datasets
50 | Please create the following directory structure:
51 | 
52 | ```
53 | datasets/
54 | ├── memorization/
55 | ├── monash/
56 | └── synthetic/
57 | ```
58 | 
59 | Place the corresponding time series datasets referenced in the paper into these directories.
60 | 
61 | 
62 | ## Acknowledgement
63 | 
64 | We would like to acknowledge the following GitHub repository for providing a valuable code base:
65 | [Large Language Models Are Zero-Shot Time Series Forecasters](https://github.com/ngruver/llmtime)
66 | 
67 | ## Citation
68 | If you find the code valuable, please use this citation:
69 | ```
70 | @article{tang2025time,
71 |   title={Time series forecasting with llms: Understanding and enhancing model capabilities},
72 |   author={Tang, Hua and Zhang, Chong and Jin, Mingyu and Yu, Qinkai and Wang, Zhenting and Jin, Xiaobo and Zhang, Yongfeng and Du, Mengnan},
73 |   journal={ACM SIGKDD Explorations Newsletter},
74 |   volume={26},
75 |   number={2},
76 |   pages={109--118},
77 |   year={2025},
78 |   publisher={ACM New York, NY, USA}
79 | }
80 | ```
81 | 
--------------------------------------------------------------------------------
/models/llms.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from models.gpt import gpt_completion_fn, gpt_nll_fn, gemini_completion_fn
3 | from models.gpt import tokenize_fn as gpt_tokenize_fn
4 | from models.llama import llama_completion_fn, llama_nll_fn
5 | from models.llama import tokenize_fn as llama_tokenize_fn
6 | 
7 | # Required: Text completion function for each model
8 | # -----------------------------------------------
9 | # Each model is mapped to a function that samples text completions.
10 | # The completion function should follow this signature:
11 | # 
12 | # Args:
13 | #   - input_str (str): String representation of the input time series.
14 | #   - steps (int): Number of steps to predict.
15 | #   - settings (SerializerSettings): Serialization settings.
16 | #   - num_samples (int): Number of completions to sample.
17 | #   - temp (float): Temperature parameter for model's output randomness.
18 | # 
19 | # Returns:
20 | #   - list: Sampled completion strings from the model.
21 | completion_fns = {
22 |     'text-davinci-003': partial(gpt_completion_fn, model='text-davinci-003'),
23 |     'gpt-4': partial(gpt_completion_fn, model='gpt-4'),
24 |     'gpt-3.5-turbo-instruct': partial(gpt_completion_fn, model='gpt-3.5-turbo-instruct'),
25 |     'llama-7b': partial(llama_completion_fn, model='7b'),
26 |     'llama-13b': partial(llama_completion_fn, model='13b'),
27 |     'llama-70b': partial(llama_completion_fn, model='70b'),
28 |     'llama-7b-chat': partial(llama_completion_fn, model='7b-chat'),
29 |     'llama-13b-chat': partial(llama_completion_fn, model='13b-chat'),
30 |     'llama-70b-chat': partial(llama_completion_fn, model='70b-chat'),
31 |     'gemini-1.0-pro': partial(gemini_completion_fn, model='gemini-1.0-pro'),
32 |     'gemini-pro': partial(gemini_completion_fn, model='gemini-pro'),
33 | }
34 | 
35 | # Optional: NLL/D functions for each model
36 | # -----------------------------------------------
37 | # Each model is mapped to a function that computes the continuous Negative Log-Likelihood
38 | # per Dimension (NLL/D). This is used for computing likelihoods only and not needed for sampling.
39 | # 
40 | # The NLL function should follow this signature:
41 | # 
42 | # Args:
43 | #   - input_arr (np.ndarray): Input time series (history) after data transformation.
44 | #   - target_arr (np.ndarray): Ground truth series (future) after data transformation.
45 | #   - settings (SerializerSettings): Serialization settings.
46 | #   - transform (callable): Data transformation function (e.g., scaling) for determining the Jacobian factor.
47 | #   - count_seps (bool): If True, count time step separators in NLL computation, required if allowing variable number of digits.
48 | #   - temp (float): Temperature parameter for sampling.
49 | # 
50 | # Returns:
51 | #   - float: Computed NLL per dimension for p(target_arr | input_arr).
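#
# For illustration (hypothetical arrays and settings, assuming an identity
# data transform; SerializerSettings comes from data1/serialize.py):
#
#   samples = completion_fns['gpt-4'](input_str, steps=10, settings=settings,
#                                     num_samples=20, temp=0.7)
#   nll = nll_fns['llama-7b'](input_arr, target_arr, settings,
#                             transform=lambda x: x, count_seps=True, temp=1.0)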
52 | nll_fns = { 53 | 'text-davinci-003': partial(gpt_nll_fn, model='text-davinci-003'), 54 | 'llama-7b': partial(llama_nll_fn, model='7b'), 55 | 'llama-13b': partial(llama_nll_fn, model='13b'), 56 | 'llama-70b': partial(llama_nll_fn, model='70b'), 57 | 'llama-7b-chat': partial(llama_nll_fn, model='7b-chat'), 58 | 'llama-13b-chat': partial(llama_nll_fn, model='13b-chat'), 59 | 'llama-70b-chat': partial(llama_nll_fn, model='70b-chat'), 60 | } 61 | 62 | # Optional: Tokenization function for each model, only needed if you want automatic input truncation. 63 | # The tokenization function should follow this signature: 64 | # 65 | # Args: 66 | # - str (str): A string to tokenize. 67 | # Returns: 68 | # - token_ids (list): A list of token ids. 69 | tokenization_fns = { 70 | 'text-davinci-003': partial(gpt_tokenize_fn, model='text-davinci-003'), 71 | 'gpt-3.5-turbo-instruct': partial(gpt_tokenize_fn, model='gpt-3.5-turbo-instruct'), 72 | 'llama-7b': partial(llama_tokenize_fn, model='7b'), 73 | 'llama-13b': partial(llama_tokenize_fn, model='13b'), 74 | 'llama-70b': partial(llama_tokenize_fn, model='70b'), 75 | 'llama-7b-chat': partial(llama_tokenize_fn, model='7b-chat'), 76 | 'llama-13b-chat': partial(llama_tokenize_fn, model='13b-chat'), 77 | 'llama-70b-chat': partial(llama_tokenize_fn, model='70b-chat'), 78 | } 79 | 80 | # Optional: Context lengths for each model, only needed if you want automatic input truncation. 81 | context_lengths = { 82 | 'text-davinci-003': 4097, 83 | 'gpt-3.5-turbo-instruct': 4097, 84 | 'llama-7b': 4096, 85 | 'llama-13b': 4096, 86 | 'llama-70b': 4096, 87 | 'llama-7b-chat': 4096, 88 | 'llama-13b-chat': 4096, 89 | 'llama-70b-chat': 4096, 90 | 'gemini-1.0-pro': 4096, 91 | 'gemini-pro': 4096 92 | } -------------------------------------------------------------------------------- /models/gaussian_process.py: -------------------------------------------------------------------------------- 1 | import gpytorch 2 | from gpytorch.kernels import SpectralMixtureKernel, RBFKernel, ScaleKernel, MaternKernel 3 | import torch 4 | from tqdm.auto import tqdm 5 | import numpy as np 6 | import pandas as pd 7 | from sklearn.preprocessing import MinMaxScaler 8 | 9 | class SpectralMixtureGPModel(gpytorch.models.ExactGP): 10 | def __init__(self, train_x, train_y, likelihood): 11 | super().__init__(train_x, train_y, likelihood) 12 | self.mean_module = gpytorch.means.ConstantMean() 13 | covar = SpectralMixtureKernel(num_mixtures=12) 14 | covar.initialize_from_data(train_x, train_y) 15 | self.covar_module = ScaleKernel(covar)+RBFKernel()#+ScaleKernel(MaternKernel) 16 | #self.covar_module.initialize_from_data(train_x, train_y) 17 | 18 | def forward(self, x): 19 | mean_x = self.mean_module(x) 20 | covar_x = self.covar_module(x) 21 | return gpytorch.distributions.MultivariateNormal(mean_x, covar_x) 22 | 23 | def train_gp(x,y, epochs=300, lr=0.05): 24 | train_x = torch.tensor(x, dtype=torch.float32).unsqueeze(-1) 25 | train_y = torch.tensor(y, dtype=torch.float32) 26 | 27 | likelihood = gpytorch.likelihoods.GaussianLikelihood() 28 | model = SpectralMixtureGPModel(train_x, train_y, likelihood) 29 | 30 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 31 | mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model) 32 | 33 | model.train() 34 | likelihood.train() 35 | 36 | for epoch in tqdm(range(epochs)): 37 | optimizer.zero_grad() 38 | output = model(train_x) 39 | loss = -mll(output, train_y) 40 | loss.backward() 41 | optimizer.step() 42 | 43 | model.eval() 44 | likelihood.eval() 45 | 46 | 
|     return model, likelihood
47 | 
48 | def test_gp(model, likelihood, test_x,test_y):
49 |     test_x = torch.tensor(test_x, dtype=torch.float32).unsqueeze(-1)
50 |     test_y = torch.tensor(test_y, dtype=torch.float32)
51 | 
52 |     with torch.no_grad(), gpytorch.settings.fast_pred_var():
53 |         preds = likelihood(model(test_x))
54 |         preds_mean = preds.mean
55 | 
56 |     rmse = torch.sqrt(torch.mean((preds_mean - test_y) ** 2)).item()
57 |     return preds_mean, rmse
58 | 
59 | def get_gp_predictions_data(train, test, epochs=300, lr=0.05, num_samples=100, **kwargs):
60 |     train = train.copy()
61 |     test = test.copy()
62 |     num_samples = max(1,num_samples)
63 |     if not isinstance(train, list):
64 |         # Assume single train/test case
65 |         train = [train]
66 |         test = [test]
67 | 
68 |     for i in range(len(train)):
69 |         if not isinstance(train[i], pd.Series):
70 |             train[i] = pd.Series(train[i], index = pd.RangeIndex(len(train[i])))
71 |             test[i] = pd.Series(test[i], index = pd.RangeIndex(len(train[i]), len(test[i])+len(train[i])))
72 | 
73 |     test_len = len(test[0])
74 |     assert all(len(t)==test_len for t in test), f'All test series must have same length, got {[len(t) for t in test]}'
75 | 
76 |     gp_models = []
77 |     gp_likelihoods = []
78 |     BPD_list = []
79 |     gp_mean_list = []
80 |     f_samples_list = []
81 | 
82 |     for train_series, test_series in zip(train, test):
83 |         # Normalize series
84 |         scaler = MinMaxScaler()
85 |         train_y = scaler.fit_transform(train_series.values.reshape(-1,1)).reshape(-1)
86 |         test_y = scaler.transform(test_series.values.reshape(-1,1)).reshape(-1)
87 | 
88 |         all_t = np.linspace(0, 1, train_series.shape[0]+test_series.shape[0])
89 |         train_x = all_t[:train_series.shape[0]]
90 |         test_x = all_t[train_series.shape[0]:]
91 | 
92 |         # Train the GP model
93 |         gp_model, gp_likelihood = train_gp(train_x, train_y, epochs=epochs, lr=lr)
94 |         gp_models.append(gp_model)
95 |         gp_likelihoods.append(gp_likelihood)
96 | 
97 |         # Test the GP model
98 |         with torch.no_grad():
99 |             observed_pred = gp_likelihood(gp_model(torch.tensor(test_x, dtype=torch.float32).unsqueeze(-1)))
100 |             BPD = -observed_pred.log_prob(torch.tensor(test_y, dtype=torch.float32))/(test_y.shape[0])
101 |             BPD -= np.log(scaler.scale_)
102 |             BPD_list.append(BPD.cpu().data.item())
103 | 
104 |         gp_mean = observed_pred.mean.numpy()
105 |         gp_mean = scaler.inverse_transform(gp_mean.reshape(-1,1)).reshape(-1)
106 | 
107 | 
108 |         f_samples = observed_pred.sample(sample_shape=torch.Size([num_samples])).numpy()
109 |         f_samples = scaler.inverse_transform(f_samples)
110 | 
111 |         if isinstance(train_series, pd.Series):
112 |             gp_mean = pd.Series(gp_mean, index=test_series.index)
113 |             f_samples = pd.DataFrame(f_samples, columns=test_series.index)
114 | 
115 |         gp_mean_list.append(gp_mean)
116 |         f_samples_list.append(f_samples)
117 | 
118 |     out_dict = {
119 |         'NLL/D': np.mean(BPD_list),
120 |         'median': gp_mean_list if len(gp_mean_list)>1 else gp_mean_list[0],
121 |         'samples': f_samples_list if len(f_samples_list)>1 else f_samples_list[0],
122 |         'info': {'Method': 'Gaussian Process','epochs':epochs, 'lr':lr}
123 |     }
124 | 
125 |     return out_dict
126 | 
--------------------------------------------------------------------------------
/data1/monash.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | from collections import defaultdict
4 | from sklearn.preprocessing import StandardScaler
5 | import datasets
6 | from datasets import load_dataset
7 | import os
8 | import pickle
9 | 
10 | fix_pred_len = {
11 
|     'australian_electricity_demand': 336,
12 |     'pedestrian_counts': 24,
13 |     'traffic_hourly': 168,
14 | }
15 | 
16 | def get_benchmark_test_sets():
17 |     test_set_dir = "datasets/monash"
18 |     if not os.path.exists(test_set_dir):
19 |         os.makedirs(test_set_dir)
20 | 
21 |     if len(os.listdir(test_set_dir)) > 0:
22 |         print(f'Loading test sets from {test_set_dir}')
23 |         test_sets = {}
24 |         for file in os.listdir(test_set_dir):
25 |             test_sets[file.split(".")[0]] = pickle.load(open(os.path.join(test_set_dir, file), 'rb'))
26 |         return test_sets
27 |     else:
28 |         print(f'No files found in {test_set_dir}. You are not using our preprocessed datasets!')
29 | 
30 |     benchmarks = {
31 |         "monash_tsf": datasets.get_dataset_config_names("monash_tsf"),
32 |     }
33 | 
34 |     test_sets = defaultdict(list)
35 |     for path in benchmarks:
36 |         pred_lens = [24, 48, 96, 192] if path == "ett" else [None]
37 |         for name in benchmarks[path]:
38 |             for pred_len in pred_lens:
39 |                 if pred_len is None:
40 |                     ds = load_dataset(path, name)
41 |                 else:
42 |                     ds = load_dataset(path, name, multivariate=False, prediction_length=pred_len)
43 | 
44 |                 train_example = ds['train'][0]['target']
45 |                 val_example = ds['validation'][0]['target']
46 | 
47 |                 if len(np.array(train_example).shape) > 1:
48 |                     print(f"Skipping {name} because it is multivariate")
49 |                     continue
50 | 
51 |                 pred_len = len(val_example) - len(train_example)
52 |                 if name in fix_pred_len:
53 |                     print(f"Fixing pred len for {name}: {pred_len} -> {fix_pred_len[name]}")
54 |                     pred_len = fix_pred_len[name]
55 | 
56 |                 tag = name
57 |                 print("Processing", tag)
58 | 
59 |                 pairs = []
60 |                 for x in ds['test']:
61 |                     if np.isnan(x['target']).any():
62 |                         print(f"Skipping {name} because it has NaNs")
63 |                         break
64 |                     history = np.array(x['target'][:-pred_len])
65 |                     target = np.array(x['target'][-pred_len:])
66 |                     pairs.append((history, target))
67 |                 else:
68 |                     scaler = None
69 |                     if path == "ett":
70 |                         trainset = np.array(ds['train'][0]['target'])
71 |                         scaler = StandardScaler().fit(trainset[:,None])
72 |                     test_sets[tag] = (pairs, scaler)
73 | 
74 |     for name in test_sets:
75 |         try:
76 |             with open(os.path.join(test_set_dir,f"{name}.pkl"), 'wb') as f:
77 |                 pickle.dump(test_sets[name], f)
78 |             print(f"Saved {name}")
79 |         except Exception:
80 |             print(f"Failed to save {name}")
81 | 
82 |     return test_sets
83 | 
84 | def get_datasets():
85 |     benchmarks = get_benchmark_test_sets()
86 |     # shuffle the benchmarks
87 |     for k, v in benchmarks.items():
88 |         x, _scaler = v # scaler is not used
89 |         train, test = zip(*x)
90 |         np.random.seed(0)
91 |         ind = np.arange(len(train))
92 |         ind = np.random.permutation(ind)
93 |         train = [train[i] for i in ind]
94 |         test = [test[i] for i in ind]
95 |         benchmarks[k] = [list(train), list(test)]
96 | 
97 |     df = pd.read_csv('data1/last_val_mae.csv')
98 |     df = df.sort_values(by='mae')
99 | 
100 |     df_paper = pd.read_csv('data1/paper_mae_raw.csv') # pdf text -> csv
101 |     datasets = df_paper['Dataset']
102 |     name_map = {
103 |         'Aus. 
Electricity Demand' :'australian_electricity_demand',
104 |         'Kaggle Weekly': 'kaggle_web_traffic_weekly',
105 |         'FRED-MD': 'fred_md',
106 |         'Saugeen River Flow': 'saugeenday',
107 | 
108 |     }
109 |     datasets = [name_map.get(d, d) for d in datasets]
110 |     # lower case and replace spaces with underscores
111 |     datasets = [d.lower().replace(' ', '_') for d in datasets]
112 |     df_paper['Dataset'] = datasets
113 |     df_paper = df_paper.reset_index(drop=True)
114 |     # for each dataset, add last value mae to df_paper
115 |     for dataset in df_paper['Dataset']:
116 |         if dataset in df['dataset'].values:
117 |             df_paper.loc[df_paper['Dataset'] == dataset, 'Last Value'] = df[df['dataset'] == dataset]['mae'].values[0]
118 |     # turn '-' into np.nan
119 |     df_paper = df_paper.replace('-', np.nan)
120 |     # convert all values to float
121 |     for method in df_paper.columns[1:]:
122 |         df_paper[method] = df_paper[method].astype(float)
123 |     df_paper.to_csv('data1/paper_mae.csv', index=False)
124 |     # normalize each method by dividing by last value mae
125 |     for method in df_paper.columns[1:-1]: # skip dataset and last value
126 |         df_paper[method] = df_paper[method] / df_paper['Last Value']
127 |     # sort df by minimum mae across methods
128 |     df_paper['normalized_min'] = df_paper[df_paper.columns[1:-1]].min(axis=1)
129 |     df_paper['normalized_median'] = df_paper[df_paper.columns[1:-1]].median(axis=1)
130 |     df_paper = df_paper.sort_values(by='normalized_min')
131 |     df_paper = df_paper.reset_index(drop=True)
132 |     # save as csv
133 |     df_paper.to_csv('data1/paper_mae_normalized.csv', index=False)
134 |     return benchmarks
135 | 
136 | def main():
137 |     get_datasets()
138 | 
139 | if __name__ == "__main__":
140 |     main()
--------------------------------------------------------------------------------
/data1/paper_mae_normalized.csv:
--------------------------------------------------------------------------------
1 | Dataset,SES,Theta,TBATS,ETS,(DHR-)ARIMA,PR,CatBoost,FFNN,DeepAR,N-BEATS,WaveNet,Transformer,Last Value,normalized_min,normalized_median
2 | sunspot,1.2533898102487173,1.2533898102487173,0.6533898199471001,1.2533898102487173,0.6533898199471001,0.9737287978199975,0.5771186347392675,2.026271153688089,0.19576270870010395,3.6788134998577977,0.04322033828443854,0.033050846923394175,3.933333396911621,0.033050846923394175,0.9737287978199975
3 | covid_deaths,1.000001700559165,0.9084293529266092,0.2722291248391112,0.24197830299075218,0.24248719532091148,0.9838019613824273,1.34333439263998,0.4075096692731279,0.5710337380309863,0.44898439418111175,2.96706845919774,1.1553552202383548,353.70939849624057,0.24197830299075218,0.9084293529266092
4 | pedestrian_counts,0.9999190187675484,1.0003286537608984,1.3013518545884437,1.266942515147037,3.7169108910891087,0.25853820008866557,0.2540322151618147,0.27158800059110383,0.26204935717452343,0.3911428993645633,0.27188059701492534,0.27673769765036205,170.8838383838384,0.2540322151618147,0.3911428993645633
5 | australian_electricity_demand,0.9999989557760464,1.0082463698443023,0.5620673330267002,1.9451010616602633,1.58568664012323,0.3747418767263844,0.36653994472100476,0.3922979529966794,0.458474354481859,0.32418098349543967,0.34490564347945807,0.35089411509151897,659.600688770839,0.32418098349543967,0.458474354481859
6 
tourism_monthly,0.9406172837969468,0.3672205640724105,0.5215839127413151,0.3556094286308854,0.4500348365775033,0.3880336795804277,0.4500827358375371,0.3587494912331008,0.33204654078759493,0.3553450956773656,0.37168583953556084,0.38088427150871706,5636.83029361023,0.33204654078759493,0.38088427150871706 7 | solar_weekly,0.6952605346361114,0.7001408138403036,0.5254106278304898,0.6539863249684281,0.48564560403045365,0.6042410145493922,0.8751485512740526,0.6076294548499331,0.4172465249944457,0.6780581286734667,1.1546659644620334,0.3332640899687479,1729.4092503457175,0.3332640899687479,0.6539863249684281 8 | us_births,1.0342972816657028,0.5091931752458068,0.34615384615384615,0.3641382301908618,0.45661943319838055,0.4987825332562174,0.3831983805668016,0.4839820705610179,0.36864950838635047,0.3661075766338924,0.437593984962406,0.39288895315211103,1152.6666666666667,0.34615384615384615,0.437593984962406 9 | traffic_hourly,1.1430142482028047,1.1430142482028047,1.5240189976037397,1.1430142482028047,1.5240189976037397,0.7620094988018699,0.7620094988018699,0.38100474940093493,0.38100474940093493,0.7620094988018699,0.7620094988018699,0.38100474940093493,0.0262463919825758,0.38100474940093493,0.7620094988018699 10 | nn5_daily,0.802396050656708,0.4598951723220951,0.44779266778730314,0.45021316869426153,0.5337204499843262,0.6620069980531211,0.5107256913682214,0.4913616841125542,0.4768386786708039,0.5954432231117652,0.4804694300312415,0.5034641886473462,8.262752532958984,0.44779266778730314,0.5034641886473462 11 | tourism_quarterly,0.9475604262423281,0.4832086797835995,0.6293693196814159,0.5632984220683516,0.6611173042494279,0.5738417444059564,0.6480217733919336,0.5668023443488726,0.6002719967809447,0.5453143137640065,0.5766527079933916,0.6009220399994131,15845.100306204946,0.4832086797835995,0.6002719967809447 12 | tourism_yearly,0.961019727838344,0.9114940348396415,0.9463584785674776,0.9533747432547206,0.9555299665041875,0.8313518044272377,0.800023897548177,0.800285319542514,0.7186218142169105,0.7133985022231359,0.702877976248726,0.7472297257918155,99456.0540551959,0.702877976248726,0.8313518044272377 13 | fred_md,0.9902846272042592,1.2361092970831906,0.7042465208588531,0.7224545760045026,1.046515489829959,3.1574572502658014,0.876138347191086,0.8279692823538782,1.5091487277143167,0.9052004558122859,0.8877178916879889,1.651302500132285,2825.672461360778,0.7042465208588531,0.9902846272042592 14 | hospital,0.9041902293894837,0.7703900208125471,0.7242663464273299,0.746704890722841,0.8144360522074393,0.7994770226770985,0.7965683224906435,0.9498983751766358,0.7583396914686615,0.8385367108952104,0.8040478372558139,1.5037979963973074,24.06573229030856,0.7242663464273299,0.8040478372558139 15 | nn5_weekly,0.937244506826888,0.9156986560952354,0.8965467887782109,0.939638490241516,0.9204866229244916,0.8941528053635828,0.9151001602415783,0.8989407721928389,0.8791904090221573,0.8492656163393065,1.1574909809726701,1.2173405663383716,16.708553516113007,0.8492656163393065,0.9156986560952354 16 | weather,0.9482724997260088,1.0625732028179828,0.9736726559686696,0.9948394528375538,1.0371730465753222,3.4586546083756655,1.0625732028179828,0.8847721091193562,0.8551385935029185,0.9906060934637768,0.9694392965948928,0.8593719528766953,2.362190193902301,0.8551385935029185,0.9906060934637768 17 | 
bitcoin,6.853292076103736,6.853292076103736,1.2729379278316508,1.414375475368501,4.654581109849067,0.8563400605412924,2.48158606132837,1.864404035713024,2.5073019790623428,1.3629436399005554,3.1630578812786476,3.355927264283443,7.777284173521224e+17,0.8563400605412924,2.5073019790623428
18 | solar_10_minutes,1.2028212854564042,1.20648842352182,3.216080083369715,1.2028212854564042,0.8691117215035604,1.2028212854564042,2.086601559221628,1.2028212854564042,1.2028212854564042,1.2908325990263851,,1.2028212854564042,2.7269221451758905,0.8691117215035604,1.2028212854564042
19 | traffic_weekly,0.9446817650985787,0.9531164237155302,0.9868550581833365,0.9615510823324818,1.0290283512680944,0.9531164237155302,0.9868550581833365,0.9699857409494334,0.9952897168002881,0.9362471064816271,1.0121590340341913,1.1977215236071264,1.1855844384623289,0.9362471064816271,0.9868550581833365
20 | kaggle_web_traffic_weekly,1.1226492097639231,1.1403600048758329,1.0768855143391425,1.2817293295689465,1.4963292133835864,1.9462900486759183,5.147207635214411,0.9728351935040241,1.0916517156339651,0.9853581234895814,0.9729648901321829,1.489263148938341,2081.781183003247,0.9728351935040241,1.1403600048758329
21 | saugeenday,1.0001550426857162,0.9996898542937693,1.0355093604736765,1.4276631748848667,1.0410916211770385,1.1741355012738361,0.9899208980628857,1.0690029246938493,1.0936579094670322,1.2988059903155906,1.0313226649461549,1.3053186278028464,21.496667098999023,0.9899208980628857,1.0690029246938493
22 | cif_2016,1.5053978709197886,1.8493397629496653,2.2135059156420556,1.6620377670885214,1.2135252080967047,1.4570948272157827,1.5614751061310725,3.870171785294736,8.279947431500512,1.7567619130249434,15.51828059207645,10.498567202673625,386526.3670424068,1.2135252080967047,1.8493397629496653
23 | dominick,,,,,,,,,,,,,,,
24 | vehicle_trips,,,,,,,,,,,,,,,
25 | kdd_cup,,,,,,,,,,,,,,,
26 | kaggle_daily,,,,,,,,,,,,,,,
27 | electricity_hourly,,,,,,,,,,,,,,,
28 | electricity_weekly,,,,,,,,,,,,,,,
29 | carparts,,,,,,,,,,,,,,,
30 | rideshare,,,,,,,,,,,,,,,
31 | temperature_rain,,,,,,,,,,,,,,,
32 | 
--------------------------------------------------------------------------------
/models/llama.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from jax import grad,vmap
4 | from tqdm import tqdm
5 | import argparse
6 | from transformers import (
7 | LlamaForCausalLM,
8 | LlamaTokenizer,
9 | )
10 | from data1.serialize import serialize_arr, deserialize_str, SerializerSettings
11 | 
12 | DEFAULT_EOS_TOKEN = "</s>"
13 | DEFAULT_BOS_TOKEN = "<s>"
14 | DEFAULT_UNK_TOKEN = "<unk>"
15 | 
16 | loaded = {}
17 | 
18 | def llama2_model_string(model_size, chat):
19 | chat = "chat-" if chat else ""
20 | return f"meta-llama/Llama-2-{model_size.lower()}-{chat}hf"
21 | 
22 | def get_tokenizer(model):
23 | name_parts = model.split("-")
24 | model_size = name_parts[0]
25 | chat = len(name_parts) > 1
26 | assert model_size in ["7b", "13b", "70b"]
27 | 
28 | tokenizer = LlamaTokenizer.from_pretrained(
29 | llama2_model_string(model_size, chat),
30 | use_fast=False,
31 | )
32 | 
33 | special_tokens_dict = dict()
34 | if tokenizer.eos_token is None:
35 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
36 | if tokenizer.bos_token is None:
37 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
38 | if tokenizer.unk_token is None:
39 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
40 | 
41 | tokenizer.add_special_tokens(special_tokens_dict)
42 | tokenizer.pad_token = tokenizer.eos_token
43
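# Llama-2 tokenizers ship without a dedicated pad token, so padding falls
# back to the EOS token above; "</s>", "<s>" and "<unk>" are the standard
# Llama special-token strings. For reference, llama2_model_string resolves
# model names as, e.g.:
#   llama2_model_string("7b", chat=True)   -> "meta-llama/Llama-2-7b-chat-hf"
#   llama2_model_string("13b", chat=False) -> "meta-llama/Llama-2-13b-hf"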
| 44 | return tokenizer 45 | 46 | def get_model_and_tokenizer(model_name, cache_model=False): 47 | if model_name in loaded: 48 | return loaded[model_name] 49 | name_parts = model_name.split("-") 50 | model_size = name_parts[0] 51 | chat = len(name_parts) > 1 52 | 53 | assert model_size in ["7b", "13b", "70b"] 54 | 55 | tokenizer = get_tokenizer(model_name) 56 | 57 | model = LlamaForCausalLM.from_pretrained( 58 | llama2_model_string(model_size, chat), 59 | device_map="auto", 60 | torch_dtype=torch.float16, 61 | ) 62 | model.eval() 63 | if cache_model: 64 | loaded[model_name] = model, tokenizer 65 | return model, tokenizer 66 | 67 | def tokenize_fn(str, model): 68 | tokenizer = get_tokenizer(model) 69 | return tokenizer(str) 70 | 71 | def llama_nll_fn(model, input_arr, target_arr, settings:SerializerSettings, transform, count_seps=True, temp=1, cache_model=True): 72 | """ Returns the NLL/dimension (log base e) of the target array (continuous) according to the LM 73 | conditioned on the input array. Applies relevant log determinant for transforms and 74 | converts from discrete NLL of the LLM to continuous by assuming uniform within the bins. 75 | inputs: 76 | input_arr: (n,) context array 77 | target_arr: (n,) ground truth array 78 | cache_model: whether to cache the model and tokenizer for faster repeated calls 79 | Returns: NLL/D 80 | """ 81 | model, tokenizer = get_model_and_tokenizer(model, cache_model=cache_model) 82 | 83 | input_str = serialize_arr(vmap(transform)(input_arr), settings) 84 | target_str = serialize_arr(vmap(transform)(target_arr), settings) 85 | full_series = input_str + target_str 86 | 87 | batch = tokenizer( 88 | [full_series], 89 | return_tensors="pt", 90 | add_special_tokens=True 91 | ) 92 | batch = {k: v.cuda() for k, v in batch.items()} 93 | 94 | with torch.no_grad(): 95 | out = model(**batch) 96 | 97 | good_tokens_str = list("0123456789" + settings.time_sep) 98 | good_tokens = [tokenizer.convert_tokens_to_ids(token) for token in good_tokens_str] 99 | bad_tokens = [i for i in range(len(tokenizer)) if i not in good_tokens] 100 | out['logits'][:,:,bad_tokens] = -100 101 | 102 | input_ids = batch['input_ids'][0][1:] 103 | logprobs = torch.nn.functional.log_softmax(out['logits'], dim=-1)[0][:-1] 104 | logprobs = logprobs[torch.arange(len(input_ids)), input_ids].cpu().numpy() 105 | 106 | tokens = tokenizer.batch_decode( 107 | input_ids, 108 | skip_special_tokens=False, 109 | clean_up_tokenization_spaces=False 110 | ) 111 | 112 | input_len = len(tokenizer([input_str], return_tensors="pt",)['input_ids'][0]) 113 | input_len = input_len - 2 # remove the BOS token 114 | 115 | logprobs = logprobs[input_len:] 116 | tokens = tokens[input_len:] 117 | BPD = -logprobs.sum()/len(target_arr) 118 | 119 | #print("BPD unadjusted:", -logprobs.sum()/len(target_arr), "BPD adjusted:", BPD) 120 | # log p(x) = log p(token) - log bin_width = log p(token) + prec * log base 121 | transformed_nll = BPD - settings.prec*np.log(settings.base) 122 | avg_logdet_dydx = np.log(vmap(grad(transform))(target_arr)).mean() 123 | return transformed_nll-avg_logdet_dydx 124 | 125 | def llama_completion_fn( 126 | model, 127 | input_str, 128 | steps, 129 | settings, 130 | batch_size=5, 131 | num_samples=20, 132 | temp=0.9, 133 | top_p=0.9, 134 | cache_model=True 135 | ): 136 | avg_tokens_per_step = len(tokenize_fn(input_str, model)['input_ids']) / len(input_str.split(settings.time_sep)) 137 | max_tokens = int(avg_tokens_per_step*steps) 138 | 139 | model, tokenizer = get_model_and_tokenizer(model, 
cache_model=cache_model)
140 | 
141 | gen_strs = []
142 | for _ in tqdm(range(num_samples // batch_size)):
143 | batch = tokenizer(
144 | [input_str],
145 | return_tensors="pt",
146 | )
147 | 
148 | batch = {k: v.repeat(batch_size, 1) for k, v in batch.items()}
149 | batch = {k: v.cuda() for k, v in batch.items()}
150 | num_input_ids = batch['input_ids'].shape[1]
151 | 
152 | good_tokens_str = list("0123456789" + settings.time_sep)
153 | good_tokens = [tokenizer.convert_tokens_to_ids(token) for token in good_tokens_str]
154 | # good_tokens += [tokenizer.eos_token_id]
155 | bad_tokens = [i for i in range(len(tokenizer)) if i not in good_tokens]
156 | 
157 | generate_ids = model.generate(
158 | **batch,
159 | do_sample=True,
160 | max_new_tokens=max_tokens,
161 | temperature=temp,
162 | top_p=top_p,
163 | bad_words_ids=[[t] for t in bad_tokens], # block all other tokens from being generated (this likely restricts the output to digits and separators)
164 | renormalize_logits=True,
165 | )
166 | gen_strs += tokenizer.batch_decode(
167 | generate_ids[:, num_input_ids:],
168 | skip_special_tokens=True,
169 | clean_up_tokenization_spaces=False
170 | )
171 | return gen_strs
172 | 
--------------------------------------------------------------------------------
/data1/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from jax import vmap
3 | import jax.numpy as jnp
4 | 
5 | def quantile_loss(target, pred, q):
6 | q_pred = jnp.quantile(pred, q, axis=0)
7 | return 2 * jnp.sum(
8 | jnp.abs((q_pred - target) * ((target <= q_pred) * 1.0 - q))
9 | )
10 | 
11 | def calculate_crps(target, pred, num_quantiles=20):
12 | quantiles = jnp.linspace(0, 1.0, num_quantiles+1)[1:]
13 | vec_quantile_loss = vmap(lambda q: quantile_loss(target, pred, q))
14 | crps = jnp.sum(vec_quantile_loss(quantiles))
15 | crps = crps / (jnp.sum(np.abs(target)) * len(quantiles))
16 | return crps
17 | 
18 | import jax
19 | from jax import grad,vmap
20 | from .serialize import serialize_arr, SerializerSettings
21 | import openai
22 | 
23 | def nll(input_arr, target_arr, model, settings:SerializerSettings, transform, count_seps=True, prompt=None, temp=1):
24 | """ Returns the NLL/dimension (log base e) of the target array (continuous) according to the LM
25 | conditioned on the input array. Applies relevant log determinant for transforms and
26 | converts from discrete NLL of the LLM to continuous by assuming uniform within the bins.
27 | inputs:
28 | input_arr: (n,) context array
29 | target_arr: (n,) ground truth array
30 | Returns: NLL/D
31 | """
32 | input_str = serialize_arr(vmap(transform)(input_arr), settings)
33 | target_str = serialize_arr(vmap(transform)(target_arr), settings)
34 | if prompt:
35 | input_str = prompt + '\n' + input_str
36 | if not input_str.endswith(settings.time_sep):
37 | print('Appending time separator to input... 
Are you sure you want this?') 38 | prompt = input_str + settings.time_sep + target_str 39 | else: 40 | prompt = input_str + target_str 41 | response = openai.Completion.create(model=model, prompt=prompt, logprobs=5, max_tokens=0, echo=True, temperature=temp) 42 | #print(response['choices'][0]) 43 | logprobs = np.array(response['choices'][0].logprobs.token_logprobs, dtype=np.float32) 44 | tokens = np.array(response['choices'][0].logprobs.tokens) 45 | top5logprobs = response['choices'][0].logprobs.top_logprobs 46 | seps = tokens==settings.time_sep 47 | target_start = np.argmax(np.cumsum(seps)==len(input_arr)) + 1 48 | logprobs = logprobs[target_start:] 49 | tokens = tokens[target_start:] 50 | top5logprobs = top5logprobs[target_start:] 51 | seps = tokens==settings.time_sep 52 | assert len(logprobs[seps]) == len(target_arr), f'There should be one separator per target. Got {len(logprobs[seps])} separators and {len(target_arr)} targets.' 53 | #adjust logprobs by removing extraneous and renormalizing (see appendix of paper) 54 | # logp' = logp - log(1-pk*pextra) 55 | allowed_tokens = [settings.bit_sep + str(i) for i in range(settings.base)] 56 | allowed_tokens += [settings.time_sep, settings.plus_sign, settings.minus_sign, settings.bit_sep+settings.decimal_point] 57 | allowed_tokens = {t for t in allowed_tokens if len(t) > 0} 58 | 59 | p_extra = np.array([sum(np.exp(ll) for k,ll in top5logprobs[i].items() if not (k in allowed_tokens)) for i in range(len(top5logprobs))]) 60 | if settings.bit_sep == '': 61 | p_extra = 0 62 | adjusted_logprobs = logprobs - np.log(1-p_extra) 63 | digits_bits = -adjusted_logprobs[~seps].sum() 64 | seps_bits = -adjusted_logprobs[seps].sum() 65 | BPD = digits_bits/len(target_arr) 66 | if count_seps: 67 | BPD += seps_bits/len(target_arr) 68 | #print("BPD unadjusted:", -logprobs.sum()/len(target_arr), "BPD adjusted:", BPD) 69 | # log p(x) = log p(token) - log bin_width = log p(token) + prec * log base 70 | transformed_nll = BPD - settings.prec*np.log(settings.base) 71 | avg_logdet_dydx = np.log(vmap(grad(transform))(target_arr)).mean() 72 | return transformed_nll-avg_logdet_dydx 73 | 74 | class Evaluator: 75 | 76 | def __init__(self): 77 | self.non_numerical_cols = [ 78 | "serialized_history", 79 | "serialized_target", 80 | "serialized_prediction", 81 | "history_len", 82 | "num_channels", 83 | "example_num", 84 | "sample_num", 85 | ] 86 | 87 | def evaluate_df(self, gt_df, pred_df): 88 | cols = [c for c in gt_df.columns if c not in self.non_numerical_cols] 89 | num_channels = gt_df["num_channels"].iloc[0] 90 | history_len = gt_df["history_len"].iloc[0] 91 | gt_vals = gt_df[cols].to_numpy().reshape(len(gt_df), -1, num_channels) # (num_examples, history_len + target_len, num_channels) 92 | gt_vals = gt_vals[:, history_len:, :] # (num_examples, target_len, num_channels) 93 | 94 | cols = [c for c in pred_df.columns if c not in self.non_numerical_cols] 95 | num_channels = pred_df["num_channels"].iloc[0] 96 | pred_df = pred_df[cols + ["example_num"]] 97 | 98 | all_pred_vals = [] 99 | for example_num in sorted(pred_df["example_num"].unique()): 100 | pred_vals = pred_df[pred_df["example_num"] == example_num][cols].to_numpy() # (num_samples, target_len * num_channels) 101 | pred_vals = pred_vals.reshape(pred_vals.shape[0], -1, num_channels) # (num_samples, target_len, num_channels) 102 | all_pred_vals.append(pred_vals) 103 | 104 | pred_vals = np.stack(all_pred_vals, axis=1) # (num_samples, num_examples, target_len, num_channels) 105 | assert gt_vals.shape == pred_vals.shape[1:] 106 
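# Shape bookkeeping for the subtraction below: gt_vals has shape
# (num_examples, target_len, num_channels) while pred_vals has shape
# (num_samples, num_examples, target_len, num_channels), so gt_vals[None]
# broadcasts the ground truth across the sample axis.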
107 | diff = (gt_vals[None] - pred_vals) # (num_samples, num_examples, target_len, num_channels)
108 | mse = np.mean(diff**2)
109 | mae = np.mean(np.abs(diff))
110 | crps = calculate_crps(gt_vals, pred_vals)
111 | 
112 | return {
113 | "mse": mse,
114 | "mae": mae,
115 | "crps": crps,
116 | }
117 | 
118 | def evaluate(self, gt, pred):
119 | '''
120 | gt: (batch_size, steps)
121 | pred: (batch_size, num_samples, steps)
122 | '''
123 | assert gt.shape == (pred.shape[0], pred.shape[2]), f"wrong shapes: gt.shape: {gt.shape}, pred.shape: {pred.shape}"
124 | diff = (gt[:, None, :] - pred) # (batch_size, num_samples, steps)
125 | mse = np.mean(diff**2)
126 | mae = np.mean(np.abs(diff))
127 | std = np.std(gt, axis=1) + 1e-8 # (batch_size,)
128 | normalized_diff = diff / std[:, None, None] # (batch_size, num_samples, steps)
129 | nmse = np.mean(normalized_diff**2)
130 | nmae = np.mean(np.abs(normalized_diff))
131 | 
132 | return {
133 | "nmse": nmse,
134 | "nmae": nmae,
135 | "mse": mse,
136 | "mae": mae,
137 | }
--------------------------------------------------------------------------------
/models/validation_likelihood_tuning.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm.auto import tqdm
3 | from concurrent.futures import ThreadPoolExecutor, as_completed
4 | from models.utils import grid_iter
5 | from dataclasses import is_dataclass
6 | from typing import Any
7 | 
8 | def make_validation_dataset(train, n_val, val_length):
9 | """Partition the training set into training and validation sets.
10 | 
11 | Args:
12 | train (list): List of time series data for training.
13 | n_val (int): Number of validation samples.
14 | val_length (int): Length of each validation sample.
15 | 
16 | Returns:
17 | tuple: Lists of training data without validation, validation data, and the number of validation samples.
18 | """
19 | assert isinstance(train, list), 'Train should be a list of series'
20 | 
21 | train_minus_val_list, val_list = [], []
22 | if n_val is None:
23 | n_val = len(train)
24 | for train_series in train[:n_val]:
25 | train_len = max(len(train_series) - val_length, 1)
26 | train_minus_val, val = train_series[:train_len], train_series[train_len:]
27 | print(f'Train length: {len(train_minus_val)}, Val length: {len(val)}')
28 | train_minus_val_list.append(train_minus_val)
29 | val_list.append(val)
30 | 
31 | return train_minus_val_list, val_list, n_val
32 | 
33 | 
34 | def evaluate_hyper(hyper, train_minus_val, val, get_predictions_fn):
35 | """Evaluate a set of hyperparameters on the validation set.
36 | 
37 | Args:
38 | hyper (dict): Dictionary of hyperparameters to evaluate.
39 | train_minus_val (list): List of training samples minus validation samples.
40 | val (list): List of validation samples.
41 | get_predictions_fn (callable): Function to get predictions.
42 | 
43 | Returns:
44 | float: NLL/D value for the given hyperparameters, averaged over each series.
45 | """
46 | assert isinstance(train_minus_val, list) and isinstance(val, list), 'Train minus val and val should be lists of series'
47 | return get_predictions_fn(train_minus_val, val, **hyper, num_samples=0)['NLL/D']
48 | 
49 | 
50 | def get_autotuned_predictions_data(train, test, hypers, num_samples, get_predictions_fn, verbose=False, parallel=True, n_train=None, n_val=None, whether_blanket=True, genai_key=None):
51 | """
52 | Automatically tunes hyperparameters based on validation likelihood and retrieves predictions using the best hyperparameters.
53 | The validation set is constructed on the fly by splitting the training set.
54 | 
55 | Args:
56 | train (list): List of time series training data.
57 | test (list): List of time series test data.
58 | hypers (Union[dict, list]): Either a dictionary specifying the grid search or an explicit list of hyperparameter settings.
59 | num_samples (int): Number of samples to retrieve.
60 | get_predictions_fn (callable): Function used to get predictions based on provided hyperparameters.
61 | verbose (bool, optional): If True, prints out detailed information during the tuning process. Defaults to False.
62 | parallel (bool, optional): If True, parallelizes the hyperparameter tuning process. Defaults to True.
63 | n_train (int, optional): Number of training samples to use. Defaults to None.
64 | n_val (int, optional): Number of validation samples to use. Defaults to None.
65 | 
66 | Returns:
67 | dict: Dictionary containing predictions, best hyperparameters, and other related information.
68 | """
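# A minimal usage sketch (hypothetical values; any get_predictions_fn with
# this repo's signature works, e.g. get_llmtime_predictions_data):
#   hypers = {'temp': [0.2, 0.7, 1.0], 'settings': [SerializerSettings(base=10, prec=3)]}
#   out = get_autotuned_predictions_data(train, test, hypers, num_samples=10,
#                                        get_predictions_fn=get_llmtime_predictions_data)
#   out['best_hyper']  # the grid point with the lowest validation NLL/D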
69 | if isinstance(hypers, dict):
70 | hypers = list(grid_iter(hypers))
71 | else:
72 | assert isinstance(hypers, list), 'Hypers must be a list or dict'
73 | if not isinstance(train, list):
74 | train = [train]
75 | test = [test]
76 | if n_val is None:
77 | n_val = len(train)
78 | if len(hypers) > 1:
79 | val_length = min(len(test[0]), int(np.mean([len(series) for series in train]) / 2))
80 | # Calculate validation set length, taking the smaller of test data length or half of average training data length
81 | train_minus_val, val, n_val = make_validation_dataset(train, n_val=n_val, val_length=val_length)
82 | # Remove validation series with smaller length than required val_length
83 | train_minus_val, val = zip(*[(train_series, val_series) for train_series, val_series in zip(train_minus_val, val) if len(val_series) == val_length])
84 | train_minus_val = list(train_minus_val)
85 | val = list(val)
86 | if len(train_minus_val) <= int(0.9 * n_val): # Threshold of 0.9 to ensure sufficient data
87 | raise ValueError(f'Removed too many validation series. Only {len(train_minus_val)} out of {n_val} series have length >= {val_length}. Try decreasing val_length.')
88 | val_nlls = []
89 | 
90 | def eval_hyper(hyper):
91 | try:
92 | return hyper, evaluate_hyper(hyper, train_minus_val, val, get_predictions_fn)
93 | except ValueError:
94 | return hyper, float('inf')
95 | 
96 | best_val_nll = float('inf')
97 | best_hyper = None
98 | if not parallel:
99 | for hyper in tqdm(hypers, desc='Hyperparameter search'):
100 | _, val_nll = eval_hyper(hyper)
101 | val_nlls.append(val_nll)
102 | if val_nll < best_val_nll:
103 | best_val_nll = val_nll
104 | best_hyper = hyper
105 | if verbose:
106 | print(f'Hyper: {hyper} \n\t Val NLL: {val_nll:.3f}')
107 | else:
108 | with ThreadPoolExecutor() as executor:
109 | futures = [executor.submit(eval_hyper, hyper) for hyper in hypers]
110 | for future in tqdm(as_completed(futures), total=len(hypers), desc='Hyperparameter search'):
111 | hyper, val_nll = future.result()
112 | val_nlls.append(val_nll)
113 | if val_nll < best_val_nll:
114 | best_val_nll = val_nll
115 | best_hyper = hyper
116 | if verbose:
117 | print(f'Hyper: {hyper} \n\t Val NLL: {val_nll:.3f}')
118 | else:
119 | best_hyper = hypers[0]
120 | best_val_nll = float('inf')
121 | print(f'Sampling with best hyperparameters... {best_hyper} \n with NLL {best_val_nll:.3f}')
122 | out = get_predictions_fn(train, test, **best_hyper, num_samples=num_samples, n_train=n_train, parallel=parallel, whether_blanket=whether_blanket, genai_key=genai_key)
123 | out['best_hyper'] = convert_to_dict(best_hyper)
124 | return out
125 | 
126 | 
127 | def convert_to_dict(obj: Any) -> Any:
128 | if isinstance(obj, dict):
129 | return {k: convert_to_dict(v) for k, v in obj.items()}
130 | elif isinstance(obj, list):
131 | return [convert_to_dict(elem) for elem in obj]
132 | elif is_dataclass(obj):
133 | return convert_to_dict(obj.__dict__)
134 | else:
135 | return obj
136 | 
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numbers
3 | import random
4 | from collections import defaultdict
5 | from collections.abc import Iterable
6 | 
7 | import itertools,operator,functools
8 | 
9 | class FixedNumpySeed(object):
10 | def __init__(self, seed):
11 | self.seed = seed
12 | def __enter__(self):
13 | self.np_rng_state = np.random.get_state()
14 | np.random.seed(self.seed)
15 | self.rand_rng_state = random.getstate()
16 | random.seed(self.seed)
17 | def __exit__(self, *args):
18 | np.random.set_state(self.np_rng_state)
19 | random.setstate(self.rand_rng_state)
20 | 
21 | class ReadOnlyDict(dict):
22 | def __readonly__(self, *args, **kwargs):
23 | raise RuntimeError("Cannot modify ReadOnlyDict")
24 | __setitem__ = __readonly__
25 | __delitem__ = __readonly__
26 | pop = __readonly__
27 | popitem = __readonly__
28 | clear = __readonly__
29 | update = __readonly__
30 | setdefault = __readonly__
31 | del __readonly__
32 | 
33 | class NoGetItLambdaDict(dict):
34 | """ Regular dict, but refuses to __getitem__ pretending
35 | the element is not there and throws a KeyError
36 | if the value is a non string iterable or a lambda """
37 | def __init__(self,d={}):
38 | super().__init__()
39 | for k,v in d.items():
40 | if isinstance(v,dict):
41 | self[k] = NoGetItLambdaDict(v)
42 | else:
43 | self[k] = v
44 | def __getitem__(self, key):
45 | value = super().__getitem__(key)
46 | if callable(value) and value.__name__ == "<lambda>":
47 | raise LookupError("You shouldn't try to retrieve lambda {} from this dict".format(value))
48 | if isinstance(value,Iterable) and not isinstance(value,(str,bytes,dict,tuple)):
49 | raise LookupError("You shouldn't try to retrieve iterable {} from this dict".format(value))
50 | return value
51 | 
52 | # pop = __readonly__
53 | # popitem = __readonly__
54 | 
55 | def sample_config(config_spec):
56 | """ Generates configs from the config spec.
57 | It will apply lambdas that depend on the config and sample from any
58 | iterables, make sure that no elements in the generated config are meant to
59 | be iterable or lambdas, strings are allowed."""
60 | cfg_all = config_spec
61 | more_work=True
62 | i=0
63 | while more_work:
64 | cfg_all, more_work = _sample_config(cfg_all,NoGetItLambdaDict(cfg_all))
65 | i+=1
66 | if i>10:
67 | raise RecursionError("config dependency unresolvable with {}".format(cfg_all))
68 | out = defaultdict(dict)
69 | out.update(cfg_all)
70 | return out
71 | 
72 | def _sample_config(config_spec,cfg_all):
73 | cfg = {}
74 | more_work = False
75 | for k,v in config_spec.items():
76 | if isinstance(v,dict):
77 | new_dict,extra_work = _sample_config(v,cfg_all)
78 | cfg[k] = new_dict
79 | more_work |= extra_work
80 | elif isinstance(v,Iterable) and not isinstance(v,(str,bytes,dict,tuple)):
81 | cfg[k] = random.choice(v)
82 | elif callable(v) and v.__name__ == "<lambda>":
83 | try:cfg[k] = v(cfg_all)
84 | except (KeyError, LookupError,Exception):
85 | cfg[k] = v # the <lambda> is kept in place instead of the value it returns
86 | more_work = True
87 | else: cfg[k] = v
88 | return cfg, more_work
89 | 
90 | def flatten(d, parent_key='', sep='/'):
91 | """An invertible dictionary flattening operation that does not clobber objs"""
92 | items = []
93 | for k, v in d.items():
94 | new_key = parent_key + sep + k if parent_key else k
95 | if isinstance(v, dict) and v: # non-empty dict
96 | items.extend(flatten(v, new_key, sep=sep).items())
97 | else:
98 | items.append((new_key, v))
99 | return dict(items)
100 | 
101 | def unflatten(d,sep='/'):
102 | """Take a dictionary with keys {'k1/k2/k3':v} to {'k1':{'k2':{'k3':v}}}
103 | as outputted by flatten """
104 | out_dict={}
105 | for k,v in d.items():
106 | if isinstance(k,str):
107 | keys = k.split(sep)
108 | dict_to_modify = out_dict
109 | for partial_key in keys[:-1]:
110 | try: dict_to_modify = dict_to_modify[partial_key]
111 | except KeyError:
112 | dict_to_modify[partial_key] = {}
113 | dict_to_modify = dict_to_modify[partial_key]
114 | # Base level reached
115 | if keys[-1] in dict_to_modify:
116 | dict_to_modify[keys[-1]].update(v)
117 | else:
118 | dict_to_modify[keys[-1]] = v
119 | else: out_dict[k]=v
120 | return out_dict
121 | 
122 | class grid_iter(object):
123 | """ Defines a length which corresponds to one full pass through the grid
124 | defined by grid variables in config_spec, but the iterator will continue iterating
125 | past that by repeating over the grid variables"""
126 | def __init__(self,config_spec,num_elements=-1,shuffle=True):
127 | self.cfg_flat = flatten(config_spec)
128 | is_grid_iterable = lambda v: (isinstance(v,Iterable) and not isinstance(v,(str,bytes,dict,tuple)))
129 | iterables = sorted({k:v for k,v in self.cfg_flat.items() if is_grid_iterable(v)}.items())
130 | if iterables: self.iter_keys,self.iter_vals = zip(*iterables)
131 | else: self.iter_keys,self.iter_vals = [],[[]]
132 | self.vals = list(itertools.product(*self.iter_vals))
133 | if shuffle:
134 | with FixedNumpySeed(0): random.shuffle(self.vals)
135 | self.num_elements = num_elements if num_elements>=0 else (-1*num_elements)*len(self)
136 | 
137 | def __iter__(self):
138 | self.i=0
139 | self.vals_iter = iter(self.vals)
140 | return self
141 | def __next__(self):
142 | self.i+=1
143 | if self.i > self.num_elements: raise StopIteration
144 | if not self.vals: v = []
145 | else:
146 | try: v = next(self.vals_iter)
147 | except StopIteration:
148 | self.vals_iter = iter(self.vals)
149 | v = next(self.vals_iter)
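# The grid wraps around: once one full pass over the Cartesian product is
# exhausted, iteration restarts from the beginning of the (seed-0 shuffled)
# product, so num_elements may exceed len(self). For a hypothetical spec
# {'a': [1, 2], 'b': [3, 4]}, one pass yields the four (a, b) combinations.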
150 | chosen_iter_params = dict(zip(self.iter_keys,v))
151 | self.cfg_flat.update(chosen_iter_params)
152 | return sample_config(unflatten(self.cfg_flat))
153 | def __len__(self):
154 | product = functools.partial(functools.reduce, operator.mul)
155 | return product(len(v) for v in self.iter_vals) if self.vals else 1
156 | 
157 | def flatten_dict(d):
158 | """ Flattens a dictionary, ignoring outer keys. Only
159 | numbers and strings allowed, others will be converted
160 | to a string. """
161 | out = {}
162 | for k,v in d.items():
163 | if isinstance(v,dict):
164 | out.update(flatten_dict(v))
165 | elif isinstance(v,(numbers.Number,str,bytes)):
166 | out[k] = v
167 | else:
168 | out[k] = str(v)
169 | return out
--------------------------------------------------------------------------------
/0_baseline_experiment_w_gemini.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "09937be4",
6 | "metadata": {},
7 | "source": [
8 | "## Experiments (baseline with GEMINI)\n",
9 | "This notebook is used to conduct baseline predictions with the pre-trained LLMs."
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "id": "initial_id",
16 | "metadata": {
17 | "ExecuteTime": {
18 | "end_time": "2024-04-10T13:56:49.155077500Z",
19 | "start_time": "2024-04-10T13:56:39.592714300Z"
20 | },
21 | "collapsed": true,
22 | "jupyter": {
23 | "outputs_hidden": true
24 | }
25 | },
26 | "outputs": [],
27 | "source": [
28 | "from utils_others import *"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "id": "4cf8aebe69a0f25a",
35 | "metadata": {
36 | "ExecuteTime": {
37 | "end_time": "2024-04-10T13:49:56.949729400Z",
38 | "start_time": "2024-04-10T13:49:56.936547200Z"
39 | },
40 | "collapsed": false,
41 | "jupyter": {
42 | "outputs_hidden": false
43 | }
44 | },
45 | "outputs": [],
46 | "source": [
47 | "genai_api_key=\"\"\n",
48 | "# put your gemini api key here"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "id": "7b92c7543f2eb2e8",
55 | "metadata": {
56 | "ExecuteTime": {
57 | "end_time": "2024-04-10T13:49:56.969729900Z",
58 | "start_time": "2024-04-10T13:49:56.954729400Z"
59 | },
60 | "collapsed": false,
61 | "jupyter": {
62 | "outputs_hidden": false
63 | }
64 | },
65 | "outputs": [],
66 | "source": [
67 | "# Model settings\n",
68 | "gemini_hypers = {\n",
69 | " # temp=[0.2, 0.4, 0.6, 0.8, 1.0],\n",
70 | " 'temp': 0.2, \n",
71 | " 'alpha': 0.95,\n",
72 | " 'beta': 0.3,\n",
73 | " 'basic': [False],\n",
74 | " 'settings': [SerializerSettings(base=10, prec=3, signed=True,half_bin_correction=True)],\n",
75 | " # prec may be 3\n",
76 | "}\n",
77 | "\n",
78 | "model_hypers = {\n",
79 | " 'gemini-1.0-pro': {'model': 'gemini-1.0-pro', **gemini_hypers},\n",
80 | " # 'gemini-pro': {'model': 'gemini-pro', **gemini_hypers}\n",
81 | "}\n",
82 | "\n",
83 | "model_predict_fns = {\n",
84 | " 'gemini-1.0-pro': get_llmtime_predictions_data, \n",
85 | " # 'gemini-pro': get_llmtime_predictions_data\n",
86 | "}\n"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "id": "e9dec56491e6776d",
93 | "metadata": {
94 | "ExecuteTime": {
95 | "end_time": "2024-04-10T13:49:57.142716900Z",
96 | "start_time": "2024-04-10T13:49:56.967728500Z"
97 | },
98 | "collapsed": false,
99 | "jupyter": {
100 | "outputs_hidden": false
101 | }
102 | },
103 | "outputs": [],
104 | "source": [
105 | "datasets = get_datasets() \n",
106 |
"datasets_tmp = get_memorization_datasets()\n", 107 | "datasets.update(datasets_tmp)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "id": "a65033b97ace9eef", 114 | "metadata": { 115 | "ExecuteTime": { 116 | "end_time": "2024-04-10T13:49:57.153717800Z", 117 | "start_time": "2024-04-10T13:49:57.139719100Z" 118 | }, 119 | "collapsed": false, 120 | "jupyter": { 121 | "outputs_hidden": false 122 | } 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | "dataset_name = 'WineDataset'\n", 127 | "train, test = datasets[dataset_name]\n", 128 | "\n", 129 | "num_samples = 10 \n" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "id": "f273361186442fd9", 136 | "metadata": { 137 | "ExecuteTime": { 138 | "end_time": "2024-04-10T13:51:50.155195400Z", 139 | "start_time": "2024-04-10T13:49:57.154717400Z" 140 | }, 141 | "collapsed": false, 142 | "jupyter": { 143 | "outputs_hidden": false 144 | } 145 | }, 146 | "outputs": [], 147 | "source": [ 148 | "out_gemini_pro, out_gemini_pro_number = prediction_gemini(model_predict_fns, train, test, model_hypers, num_samples=num_samples, whether_blanket=False, dataset_name='WineDataset', genai_key=genai_api_key)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "id": "dc5ddf307bbb9f16", 155 | "metadata": { 156 | "ExecuteTime": { 157 | "start_time": "2024-04-10T13:51:50.152110300Z" 158 | }, 159 | "collapsed": false, 160 | "jupyter": { 161 | "outputs_hidden": false 162 | } 163 | }, 164 | "outputs": [], 165 | "source": [ 166 | "plot_preds_w_train_test(train, test, out_gemini_pro_number['WineDataset'], model_name='gemini-1.0-pro', ds_name='WineDataset', show_samples=False)\n" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "id": "3cd41c5000718220", 173 | "metadata": { 174 | "collapsed": false, 175 | "jupyter": { 176 | "outputs_hidden": false 177 | } 178 | }, 179 | "outputs": [], 180 | "source": [ 181 | "plot_preds_w_test(test, out_gemini_pro_number['WineDataset'], model_name='gemini-1.0-pro', ds_name='WineDataset', show_samples=False)" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "id": "2497807aaf9409a1", 188 | "metadata": { 189 | "collapsed": false, 190 | "jupyter": { 191 | "outputs_hidden": false 192 | } 193 | }, 194 | "outputs": [], 195 | "source": [ 196 | "\n", 197 | "mse_mean, mae_mean, mape_mean, r2_mean = metrics_used(test=test, dataset_name=dataset_name, original_pred=out_gemini_pro_number, num_samples=num_samples)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "id": "65a0035f", 203 | "metadata": {}, 204 | "source": [ 205 | "Hyperparameter Tuning" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "id": "a9d686eecd5587e5", 212 | "metadata": { 213 | "collapsed": false, 214 | "jupyter": { 215 | "outputs_hidden": false 216 | } 217 | }, 218 | "outputs": [], 219 | "source": [ 220 | "temp_list = [0.2, 0.4, 0.6, 0.8, 1.0]\n", 221 | "prec_list = [2, 3]\n", 222 | "\n", 223 | "opt_hyper_gemini(model_predict_fns, train, test, model_hypers, num_samples=num_samples, whether_blanket=False,\n", 224 | " dataset_name=dataset_name, genai_key=genai_api_key, temp_list=temp_list, prec_list=prec_list)" 225 | ] 226 | } 227 | ], 228 | "metadata": { 229 | "kernelspec": { 230 | "display_name": "Python 3 (ipykernel)", 231 | "language": "python", 232 | "name": "python3" 233 | }, 234 | "language_info": { 235 | "codemirror_mode": { 236 | 
"name": "ipython", 237 | "version": 3 238 | }, 239 | "file_extension": ".py", 240 | "mimetype": "text/x-python", 241 | "name": "python", 242 | "nbconvert_exporter": "python", 243 | "pygments_lexer": "ipython3", 244 | "version": "3.11.8" 245 | } 246 | }, 247 | "nbformat": 4, 248 | "nbformat_minor": 5 249 | } 250 | -------------------------------------------------------------------------------- /4_counterfactual_analysis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "22dfaf15", 6 | "metadata": {}, 7 | "source": [ 8 | "### Counterfactual analysis" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "initial_id", 15 | "metadata": { 16 | "ExecuteTime": { 17 | "end_time": "2024-04-10T14:46:42.145299300Z", 18 | "start_time": "2024-04-10T14:46:30.911475100Z" 19 | }, 20 | "collapsed": true, 21 | "jupyter": { 22 | "outputs_hidden": true 23 | } 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "from utils_others import *\n", 28 | "import json" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "fbd2507b1c9346db", 35 | "metadata": { 36 | "ExecuteTime": { 37 | "end_time": "2024-04-10T14:46:42.167720700Z", 38 | "start_time": "2024-04-10T14:46:42.147815Z" 39 | }, 40 | "collapsed": false, 41 | "jupyter": { 42 | "outputs_hidden": false 43 | } 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "with open('config.json', 'r', encoding='utf-8') as f:\n", 48 | " config = json.load(f)\n", 49 | "\n", 50 | "api_key = config[\"GEMINI_API_KEY\"]\n" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "id": "27711763b12dd009", 57 | "metadata": { 58 | "ExecuteTime": { 59 | "end_time": "2024-04-10T14:46:42.187294800Z", 60 | "start_time": "2024-04-10T14:46:42.166720600Z" 61 | }, 62 | "collapsed": false, 63 | "jupyter": { 64 | "outputs_hidden": false 65 | } 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "gemini_hypers = {\n", 70 | " 'temp': 0.2, \n", 71 | " 'alpha': 0.95,\n", 72 | " 'beta': 0.3,\n", 73 | " 'basic': [False],\n", 74 | " 'settings': [SerializerSettings(base=10, prec=3, signed=True,half_bin_correction=True)],\n", 75 | "}\n", 76 | "\n", 77 | "model_hypers = {\n", 78 | " 'gemini-1.0-pro': {'model': 'gemini-1.0-pro', **gemini_hypers},\n", 79 | "}\n", 80 | "\n", 81 | "model_predict_fns = {\n", 82 | " 'gemini-1.0-pro': get_llmtime_predictions_data,\n", 83 | "}\n" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "id": "9ce01411623a3fec", 90 | "metadata": { 91 | "ExecuteTime": { 92 | "end_time": "2024-04-10T14:46:42.381691500Z", 93 | "start_time": "2024-04-10T14:46:42.179723400Z" 94 | }, 95 | "collapsed": false, 96 | "jupyter": { 97 | "outputs_hidden": false 98 | } 99 | }, 100 | "outputs": [], 101 | "source": [ 102 | "datasets = get_datasets() \n", 103 | "datasets_tmp = get_memorization_datasets()\n", 104 | "datasets.update(datasets_tmp)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "id": "ddd96fa2b4a8e03", 111 | "metadata": { 112 | "ExecuteTime": { 113 | "end_time": "2024-04-10T14:46:42.430087Z", 114 | "start_time": "2024-04-10T14:46:42.380698200Z" 115 | }, 116 | "collapsed": false, 117 | "jupyter": { 118 | "outputs_hidden": false 119 | } 120 | }, 121 | "outputs": [], 122 | "source": [ 123 | "# The free version of Gemini always throws an error (security failure) when running the IstanbulTraffic dataset. 
Other datasets occasionally throw errors as well.\n",
124 | "dataset_name = 'WineDataset' \n",
125 | "train, test = datasets[dataset_name]\n",
126 | "\n",
127 | "num_samples = 10 "
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": null,
133 | "id": "abafa4c2eab1f5ef",
134 | "metadata": {
135 | "ExecuteTime": {
136 | "end_time": "2024-04-10T14:46:42.431072100Z",
137 | "start_time": "2024-04-10T14:46:42.396856500Z"
138 | },
139 | "collapsed": false,
140 | "jupyter": {
141 | "outputs_hidden": false
142 | }
143 | },
144 | "outputs": [],
145 | "source": [
146 | "def adding_noise(train, loc=0, steps=4, std=5):\n",
147 | " np.random.seed(42)\n",
148 | " noise = np.random.normal(0, std, steps)\n",
149 | " train[loc:(loc+steps)] += noise\n",
150 | " return train"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": null,
156 | "id": "f2c83fabad4deb83",
157 | "metadata": {
158 | "ExecuteTime": {
159 | "end_time": "2024-04-10T14:48:10.784627300Z",
160 | "start_time": "2024-04-10T14:46:42.421467100Z"
161 | },
162 | "collapsed": false,
163 | "jupyter": {
164 | "outputs_hidden": false
165 | }
166 | },
167 | "outputs": [],
168 | "source": [
169 | "period = 12\n",
170 | "std = 0.1*(train.max()-train.min())\n",
171 | "num_iteration = int(len(train)/period)\n",
172 | "\n",
173 | "out_counterfactual = {}\n",
174 | "MSE =[]\n",
175 | "MAE =[]\n",
176 | "MAPE = []\n",
177 | "R = []\n",
178 | "\n",
179 | "for index in range(num_iteration):\n",
180 | " MSE_tmp =[]\n",
181 | " MAE_tmp =[]\n",
182 | " MAPE_tmp = []\n",
183 | " R_tmp = []\n",
184 | " train_tmp = copy.deepcopy(train)\n",
185 | " loc = index * period\n",
186 | " train_tmp = adding_noise(train_tmp, loc=loc, steps=period, std=std)\n",
187 | " if index > 0:\n",
188 | " time.sleep(60) \n",
189 | " # predict from the perturbed series so the counterfactual actually takes effect\n",
190 | " out_gemini_pro, out_gemini_pro_number = prediction_gemini(model_predict_fns, train_tmp, test, model_hypers, num_samples=num_samples, whether_blanket=False, dataset_name=dataset_name, genai_key=api_key)\n",
191 | "\n",
192 | " mse_amount = 0.0\n",
193 | " mae_amount = 0.0\n",
194 | " mape_amount = 0.0\n",
195 | " rsquare_amount = 0.0\n",
196 | " for i in range(num_samples):\n",
197 | " # seq_pred = out_gemini_pro[dataset_name]['samples'].iloc[i, :] \n",
198 | " seq_pred = out_gemini_pro_number[dataset_name]['samples'].iloc[i, :] \n",
199 | " \n",
200 | " mse = mean_squared_error(test, seq_pred)\n",
201 | " mae = mean_absolute_error(test, seq_pred)\n",
202 | " mape = metrics.mean_absolute_percentage_error(test, seq_pred)*100\n",
203 | " r2 = r2_score(test, seq_pred)\n",
204 | " \n",
205 | " MSE_tmp.append(mse)\n",
206 | " MAE_tmp.append(mae)\n",
207 | " MAPE_tmp.append(mape)\n",
208 | " R_tmp.append(r2)\n",
209 | " \n",
210 | " MSE.append(MSE_tmp)\n",
211 | " MAE.append(MAE_tmp)\n",
212 | " MAPE.append(MAPE_tmp)\n",
213 | " R.append(R_tmp)\n"
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": null,
219 | "id": "229ee28327a4dce8",
220 | "metadata": {
221 | "ExecuteTime": {
222 | "end_time": "2024-04-10T14:48:13.909609900Z",
223 | "start_time": "2024-04-10T14:48:12.848055500Z"
224 | },
225 | "collapsed": false,
226 | "jupyter": {
227 | "outputs_hidden": false
228 | }
229 | },
230 | "outputs": [],
231 | "source": [
232 | "fig_counterfactual(R, metric_name='R^2', dataset_name='WineDataset')"
233 | ]
234 | }
235 | ],
236 | "metadata": {
237 | "kernelspec": {
238 | "display_name": "Python 3 (ipykernel)",
239 | "language": "python",
240 | "name": "python3"
241 | },
242 | "language_info": {
243 |
"codemirror_mode": { 243 | "name": "ipython", 244 | "version": 3 245 | }, 246 | "file_extension": ".py", 247 | "mimetype": "text/x-python", 248 | "name": "python", 249 | "nbconvert_exporter": "python", 250 | "pygments_lexer": "ipython3", 251 | "version": "3.11.8" 252 | } 253 | }, 254 | "nbformat": 4, 255 | "nbformat_minor": 5 256 | } 257 | -------------------------------------------------------------------------------- /models/gpt.py: -------------------------------------------------------------------------------- 1 | from data1.serialize import serialize_arr, SerializerSettings 2 | import openai 3 | import tiktoken 4 | import numpy as np 5 | from jax import grad, vmap 6 | import google.generativeai as genai 7 | 8 | safety_settings = [ 9 | {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, 10 | {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, 11 | {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, 12 | {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, 13 | ] 14 | 15 | def tokenize_fn(string, model): 16 | """ 17 | Tokenize a string using the tokenizer for the specified GPT model. 18 | """ 19 | encoding = tiktoken.encoding_for_model(model) 20 | return encoding.encode(string) 21 | 22 | def get_allowed_ids(strings, model): 23 | """ 24 | Tokenize a list of strings using the tokenizer for the specified GPT model. 25 | """ 26 | encoding = tiktoken.encoding_for_model(model) 27 | ids = [] 28 | for s in strings: 29 | ids.extend(encoding.encode(s)) 30 | return ids 31 | 32 | def gemini_completion_fn(model, input_str, steps, settings, num_samples, temp, whether_blanket=True, genai_key=None): 33 | """ 34 | Generate text completions using Google's Gemini API. 35 | """ 36 | if genai_key is not None: 37 | genai.configure(api_key=genai_key, transport='rest') 38 | 39 | allowed_tokens = [settings.bit_sep + str(i) for i in range(settings.base)] 40 | allowed_tokens += [settings.time_sep, settings.plus_sign, settings.minus_sign] 41 | allowed_tokens = [t for t in allowed_tokens if len(t) > 0] 42 | 43 | if not whether_blanket: 44 | input_str = input_str.replace(" ", "") 45 | 46 | if model not in ['gemini-1.0-pro', 'gemini-pro']: 47 | logit_bias = {id: 30 for id in get_allowed_ids(allowed_tokens, model)} 48 | 49 | if model in ['gemini-1.0-pro', 'gemini-pro']: 50 | gemini_sys_message = ( 51 | f"You are a helpful assistant that performs time series predictions. " 52 | f"The user will provide a sequence and you will predict the remaining sequence for {steps * 10} steps. " 53 | f"The sequence is represented by decimal strings separated by commas, and each step consists of contents between two commas." 54 | ) 55 | extra_input = ( 56 | "Please continue the following sequence without producing any additional text. " 57 | "Do not include phrases like 'the next terms in the sequence are'. 
Just return the numbers.\n" 58 | "Sequence:\n" 59 | ) 60 | 61 | content_fin = [] 62 | model = genai.GenerativeModel(model) 63 | for i in range(num_samples): 64 | print("Index:", i) 65 | response = model.generate_content( 66 | contents=gemini_sys_message + extra_input + input_str + settings.time_sep, 67 | generation_config=genai.types.GenerationConfig( 68 | temperature=temp, 69 | ), 70 | safety_settings=safety_settings 71 | ) 72 | tmp = response.text 73 | if not whether_blanket: 74 | tmp = ' '.join(response.text) 75 | content_fin.append(tmp) 76 | return content_fin 77 | else: 78 | assert False 79 | 80 | def gpt_completion_fn(model, input_str, steps, settings, num_samples, temp, whether_blanket=True): 81 | """ 82 | Generate text completions using OpenAI GPT models. 83 | """ 84 | avg_tokens_per_step = len(tokenize_fn(input_str, model)) / len(input_str.split(settings.time_sep)) 85 | 86 | allowed_tokens = [settings.bit_sep + str(i) for i in range(settings.base)] 87 | allowed_tokens += [settings.time_sep, settings.plus_sign, settings.minus_sign] 88 | allowed_tokens = [t for t in allowed_tokens if len(t) > 0] 89 | 90 | logit_bias = {} 91 | if model not in ['gpt-3.5-turbo', 'gpt-4']: 92 | logit_bias = {id: 30 for id in get_allowed_ids(allowed_tokens, model)} 93 | 94 | if model in ['gpt-3.5-turbo', 'gpt-4']: 95 | chatgpt_sys_message = ( 96 | "You are a helpful assistant that performs time series predictions. " 97 | "The user will provide a sequence and you will predict the remaining sequence. " 98 | "The sequence is represented by decimal strings separated by commas." 99 | ) 100 | extra_input = ( 101 | "Please continue the following sequence without producing any additional text. " 102 | "Do not include phrases like 'the next terms in the sequence are'. Just return the numbers.\n" 103 | "Sequence:\n" 104 | ) 105 | response = openai.ChatCompletion.create( 106 | model=model, 107 | messages=[ 108 | {"role": "system", "content": chatgpt_sys_message}, 109 | {"role": "user", "content": extra_input + input_str + settings.time_sep} 110 | ], 111 | max_tokens=int(avg_tokens_per_step * steps), 112 | temperature=temp, 113 | logit_bias=logit_bias, 114 | n=num_samples 115 | ) 116 | return [choice.message.content for choice in response.choices] 117 | else: 118 | response = openai.Completion.create( 119 | model=model, 120 | prompt=input_str, 121 | max_tokens=int(avg_tokens_per_step * steps), 122 | temperature=temp, 123 | logit_bias=logit_bias, 124 | n=num_samples 125 | ) 126 | return [choice.text for choice in response.choices] 127 | 128 | def gpt_nll_fn(model, input_arr, target_arr, settings: SerializerSettings, transform, count_seps=True, temp=1): 129 | """ 130 | Compute the Negative Log-Likelihood (NLL) per dimension under the LLM. 
131 | """ 132 | input_str = serialize_arr(vmap(transform)(input_arr), settings) 133 | target_str = serialize_arr(vmap(transform)(target_arr), settings) 134 | assert input_str.endswith(settings.time_sep), f'Input string must end with {settings.time_sep}, got {input_str}' 135 | 136 | full_series = input_str + target_str 137 | response = openai.Completion.create( 138 | model=model, 139 | prompt=full_series, 140 | logprobs=5, 141 | max_tokens=0, 142 | echo=True, 143 | temperature=temp 144 | ) 145 | logprobs = np.array(response['choices'][0].logprobs.token_logprobs, dtype=np.float32) 146 | tokens = np.array(response['choices'][0].logprobs.tokens) 147 | top5logprobs = response['choices'][0].logprobs.top_logprobs 148 | 149 | seps = tokens == settings.time_sep 150 | target_start = np.argmax(np.cumsum(seps) == len(input_arr)) + 1 151 | logprobs = logprobs[target_start:] 152 | tokens = tokens[target_start:] 153 | top5logprobs = top5logprobs[target_start:] 154 | seps = tokens == settings.time_sep 155 | 156 | assert len(logprobs[seps]) == len(target_arr), ( 157 | f'There should be one separator per target. Got {len(logprobs[seps])} separators and {len(target_arr)} targets.' 158 | ) 159 | 160 | allowed_tokens = [settings.bit_sep + str(i) for i in range(settings.base)] 161 | allowed_tokens += [ 162 | settings.time_sep, 163 | settings.plus_sign, 164 | settings.minus_sign, 165 | settings.bit_sep + settings.decimal_point 166 | ] 167 | allowed_tokens = {t for t in allowed_tokens if len(t) > 0} 168 | 169 | p_extra = np.array([ 170 | sum(np.exp(lp) for k, lp in top5logprobs[i].items() if k not in allowed_tokens) 171 | for i in range(len(top5logprobs)) 172 | ]) 173 | if settings.bit_sep == '': 174 | p_extra = 0 175 | 176 | adjusted_logprobs = logprobs - np.log(1 - p_extra) 177 | digits_bits = -adjusted_logprobs[~seps].sum() 178 | seps_bits = -adjusted_logprobs[seps].sum() 179 | 180 | BPD = digits_bits / len(target_arr) 181 | if count_seps: 182 | BPD += seps_bits / len(target_arr) 183 | 184 | transformed_nll = BPD - settings.prec * np.log(settings.base) 185 | avg_logdet_dydx = np.log(vmap(grad(transform))(target_arr)).mean() 186 | return transformed_nll - avg_logdet_dydx 187 | -------------------------------------------------------------------------------- /data1/serialize.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import numpy as np 3 | from dataclasses import dataclass 4 | 5 | def vec_num2repr(val, base, prec, max_val): 6 | """ 7 | Convert numbers to a representation in a specified base with precision. 8 | 9 | Parameters: 10 | - val (np.array): The numbers to represent. 11 | - base (int): The base of the representation. 12 | - prec (int): The precision after the 'decimal' point in the base representation. 13 | - max_val (float): The maximum absolute value of the number. 14 | 15 | Returns: 16 | - tuple: Sign and digits in the specified base representation. 
17 | 18 | Examples: 19 | With base=10, prec=2: 20 | 0.5 -> 50 21 | 3.52 -> 352 22 | 12.5 -> 1250 23 | """ 24 | base = float(base) 25 | bs = val.shape[0] 26 | sign = 1 * (val >= 0) - 1 * (val < 0) 27 | val = np.abs(val) 28 | max_bit_pos = int(np.ceil(np.log(max_val) / np.log(base)).item()) 29 | 30 | before_decimals = [] 31 | for i in range(max_bit_pos): 32 | digit = (val / base**(max_bit_pos - i - 1)).astype(int) 33 | before_decimals.append(digit) 34 | val -= digit * base**(max_bit_pos - i - 1) 35 | 36 | before_decimals = np.stack(before_decimals, axis=-1) 37 | 38 | if prec > 0: 39 | after_decimals = [] 40 | for i in range(prec): 41 | digit = (val / base**(-i - 1)).astype(int) 42 | after_decimals.append(digit) 43 | val -= digit * base**(-i - 1) 44 | 45 | after_decimals = np.stack(after_decimals, axis=-1) 46 | digits = np.concatenate([before_decimals, after_decimals], axis=-1) 47 | else: 48 | digits = before_decimals 49 | return sign, digits 50 | 51 | def vec_repr2num(sign, digits, base, prec, half_bin_correction=True): 52 | """ 53 | Convert a string representation in a specified base back to numbers. 54 | 55 | Parameters: 56 | - sign (np.array): The sign of the numbers. 57 | - digits (np.array): Digits of the numbers in the specified base. 58 | - base (int): The base of the representation. 59 | - prec (int): The precision after the 'decimal' point in the base representation. 60 | - half_bin_correction (bool): If True, adds 0.5 of the smallest bin size to the number. 61 | 62 | Returns: 63 | - np.array: Numbers corresponding to the given base representation. 64 | """ 65 | base = float(base) 66 | bs, D = digits.shape 67 | digits_flipped = np.flip(digits, axis=-1) 68 | powers = -np.arange(-prec, -prec + D) 69 | val = np.sum(digits_flipped/base**powers, axis=-1) 70 | 71 | if half_bin_correction: 72 | val += 0.5/base**prec 73 | 74 | return sign * val 75 | 76 | @dataclass 77 | class SerializerSettings: 78 | """ 79 | Settings for serialization of numbers. 80 | 81 | Attributes: 82 | - base (int): The base for number representation. 83 | - prec (int): The precision after the 'decimal' point in the base representation. 84 | - signed (bool): If True, allows negative numbers. Default is False. 85 | - fixed_length (bool): If True, ensures fixed length of serialized string. Default is False. 86 | - max_val (float): Maximum absolute value of number for serialization. 87 | - time_sep (str): Separator for different time steps. 88 | - bit_sep (str): Separator for individual digits. 89 | - plus_sign (str): String representation for positive sign. 90 | - minus_sign (str): String representation for negative sign. 91 | - half_bin_correction (bool): If True, applies half bin correction during deserialization. Default is True. 92 | - decimal_point (str): String representation for the decimal point. 93 | """ 94 | base: int = 10 95 | prec: int = 3 96 | signed: bool = True 97 | fixed_length: bool = False 98 | max_val: float = 1e7 99 | time_sep: str = ' ,' 100 | bit_sep: str = ' ' 101 | plus_sign: str = '' 102 | minus_sign: str = ' -' 103 | half_bin_correction: bool = True 104 | decimal_point: str = '' 105 | missing_str: str = ' Nan' 106 | 107 | def serialize_arr(arr, settings: SerializerSettings): 108 | """ 109 | Serialize an array of numbers (a time series) into a string based on the provided settings. 110 | 111 | Parameters: 112 | - arr (np.array): Array of numbers to serialize. 113 | - settings (SerializerSettings): Settings for serialization. 114 | 115 | Returns: 116 | - str: String representation of the array. 
117 | """
118 | # max_val is only for fixing the number of bits in vec_num2repr so it can be vmapped
119 | assert np.all(np.abs(arr[~np.isnan(arr)]) <= settings.max_val), f"abs(arr) must be <= max_val,\
120 | but abs(arr)={np.abs(arr)}, max_val={settings.max_val}"
121 | 
122 | if not settings.signed:
123 | assert np.all(arr[~np.isnan(arr)] >= 0), "unsigned arr must be >= 0"
124 | plus_sign = minus_sign = ''
125 | else:
126 | plus_sign = settings.plus_sign
127 | minus_sign = settings.minus_sign
128 | 
129 | vnum2repr = partial(vec_num2repr,base=settings.base,prec=settings.prec,max_val=settings.max_val)
130 | sign_arr, digits_arr = vnum2repr(np.where(np.isnan(arr),np.zeros_like(arr),arr))
131 | ismissing = np.isnan(arr)
132 | 
133 | def tokenize(arr):
134 | return ''.join([settings.bit_sep+str(b) for b in arr])
135 | 
136 | bit_strs = []
137 | for sign, digits,missing in zip(sign_arr, digits_arr, ismissing):
138 | if not settings.fixed_length:
139 | # remove leading zeros
140 | nonzero_indices = np.where(digits != 0)[0]
141 | if len(nonzero_indices) == 0:
142 | digits = np.array([0])
143 | else:
144 | digits = digits[nonzero_indices[0]:]
145 | # add a decimal point
146 | prec = settings.prec
147 | if len(settings.decimal_point):
148 | digits = np.concatenate([digits[:-prec], np.array([settings.decimal_point]), digits[-prec:]])
149 | digits = tokenize(digits)
150 | sign_sep = plus_sign if sign==1 else minus_sign
151 | if missing:
152 | bit_strs.append(settings.missing_str)
153 | else:
154 | bit_strs.append(sign_sep + digits)
155 | bit_str = settings.time_sep.join(bit_strs)
156 | bit_str += settings.time_sep # otherwise there is ambiguity in number of digits in the last time step
157 | return bit_str
158 | 
159 | def deserialize_str(bit_str, settings: SerializerSettings, ignore_last=False, steps=None):
160 | """
161 | Deserialize a string into an array of numbers (a time series) based on the provided settings.
162 | 
163 | Parameters:
164 | - bit_str (str): String representation of an array of numbers.
165 | - settings (SerializerSettings): Settings for deserialization.
166 | - ignore_last (bool): If True, ignores the last time step in the string (which may be incomplete due to token limit etc.). Default is False.
167 | - steps (int, optional): Number of steps or entries to deserialize.
168 | 
169 | Returns:
170 | - None if deserialization failed for the very first number, otherwise
171 | - np.array: Array of numbers corresponding to the string. 
172 | """ 173 | # ignore_last is for ignoring the last time step in the prediction, which is often a partially generated due to token limit 174 | orig_bitstring = bit_str 175 | bit_strs = bit_str.split(settings.time_sep) 176 | # remove empty strings 177 | bit_strs = [a for a in bit_strs if len(a) > 0] 178 | if ignore_last: 179 | bit_strs = bit_strs[:-1] 180 | if steps is not None: 181 | bit_strs = bit_strs[:steps] 182 | vrepr2num = partial(vec_repr2num,base=settings.base,prec=settings.prec,half_bin_correction=settings.half_bin_correction) 183 | max_bit_pos = int(np.ceil(np.log(settings.max_val)/np.log(settings.base)).item()) 184 | sign_arr = [] 185 | digits_arr = [] 186 | try: 187 | for i, bit_str in enumerate(bit_strs): 188 | if bit_str.startswith(settings.minus_sign): 189 | sign = -1 190 | elif bit_str.startswith(settings.plus_sign): 191 | sign = 1 192 | else: 193 | assert settings.signed == False, f"signed bit_str must start with {settings.minus_sign} or {settings.plus_sign}" 194 | bit_str = bit_str[len(settings.plus_sign):] if sign==1 else bit_str[len(settings.minus_sign):] 195 | if settings.bit_sep=='': 196 | bits = [b for b in bit_str.lstrip()] 197 | else: 198 | bits = [b[:1] for b in bit_str.lstrip().split(settings.bit_sep)] 199 | if settings.fixed_length: 200 | assert len(bits) == max_bit_pos+settings.prec, f"fixed length bit_str must have {max_bit_pos+settings.prec} bits, but has {len(bits)}: '{bit_str}'" 201 | digits = [] 202 | for b in bits: 203 | if b==settings.decimal_point: 204 | continue 205 | # check if is a digit 206 | if b.isdigit(): 207 | digits.append(int(b)) 208 | else: 209 | break 210 | #digits = [int(b) for b in bits] 211 | sign_arr.append(sign) 212 | digits_arr.append(digits) 213 | except Exception as e: 214 | print(f"Error deserializing {settings.time_sep.join(bit_strs[i-2:i+5])}{settings.time_sep}\n\t{e}") 215 | print(f'Got {orig_bitstring}') 216 | print(f"Bitstr {bit_str}, separator {settings.bit_sep}") 217 | # At this point, we have already deserialized some of the bit_strs, so we return those below 218 | if digits_arr: 219 | # add leading zeros to get to equal lengths 220 | max_len = max([len(d) for d in digits_arr]) 221 | for i in range(len(digits_arr)): 222 | digits_arr[i] = [0]*(max_len-len(digits_arr[i])) + digits_arr[i] 223 | return vrepr2num(np.array(sign_arr), np.array(digits_arr)) 224 | else: 225 | # errored at first step 226 | return None 227 | -------------------------------------------------------------------------------- /models/darts.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from darts import TimeSeries 3 | import darts.models 4 | import numpy as np 5 | from darts.utils.likelihood_models import LaplaceLikelihood, GaussianLikelihood 6 | #from darts.dataprocessing.transformers import Scaler 7 | from sklearn.preprocessing import MinMaxScaler 8 | import torch 9 | 10 | likelihoods = {'laplace': LaplaceLikelihood(), 'gaussian': GaussianLikelihood()} 11 | 12 | def get_TCN_predictions_data(*args,**kwargs): 13 | out = get_chunked_AR_predictions_data(darts.models.TCNModel,*args,**kwargs) 14 | out['info']['Method'] = 'TCN' 15 | return out 16 | 17 | def get_NHITS_predictions_data(*args,**kwargs): 18 | out = get_chunked_AR_predictions_data(darts.models.NHiTSModel,*args,**kwargs) 19 | out['info']['Method'] = 'NHiTS' 20 | return out 21 | 22 | def get_NBEATS_predictions_data(*args,**kwargs): 23 | out = get_chunked_AR_predictions_data(darts.models.NBEATSModel,*args,**kwargs) 24 | 
out['info']['Method'] = 'NBEATS' 25 | return out 26 | 27 | def get_chunked_AR_predictions_data(modeltype,train,test, epochs=400, in_len=12, out_len=12, likelihood='laplace', num_samples=100, n_train=None, **kwargs): 28 | if not isinstance(train, list): 29 | # assume single train/test case 30 | train = [train] 31 | test = [test] 32 | for i in range(len(train)): 33 | # model expects training data1 to have len at least in_len+out_len 34 | in_len = min(in_len,len(train[i])-out_len) 35 | assert in_len > 0, f'Input length must be greater than 0, got {in_len} after subtracting out_len={out_len} from len(train)={len(train)}' 36 | if not isinstance(train[i], pd.Series): 37 | train[i] = pd.Series(train[i], index = pd.RangeIndex(len(train[i]))) 38 | test[i] = pd.Series(test[i], index = pd.RangeIndex(len(train[i]),len(test[i])+len(train[i]))) 39 | 40 | test_len = len(test[0]) 41 | assert all(len(t)==test_len for t in test), f'All test series must have same length, got {[len(t) for t in test]}' 42 | 43 | model = modeltype( 44 | input_chunk_length=in_len, 45 | output_chunk_length=out_len, 46 | random_state=42, 47 | likelihood=likelihoods[likelihood], 48 | pl_trainer_kwargs={ 49 | "accelerator": "gpu", 50 | "devices": [0], 51 | "max_steps": 10000, 52 | }, 53 | **kwargs 54 | ) 55 | 56 | scaled_train_ts_list = [] 57 | scaled_test_ts_list = [] 58 | scaled_combined_series_list = [] 59 | 60 | scaler = MinMaxScaler() 61 | 62 | # Concatenate all series and fit the scaler 63 | all_series = train + test 64 | combined = pd.concat(all_series) 65 | scaler.fit(combined.values.reshape(-1,1)) 66 | 67 | # Iterate over each series in the train list 68 | for train_series, test_series in zip(train,test): 69 | scaled_train_series = scaler.transform(train_series.values.reshape(-1,1)).reshape(-1) 70 | scaled_train_series_ts = TimeSeries.from_times_and_values(train_series.index, scaled_train_series) 71 | scaled_train_ts_list.append(scaled_train_series_ts) 72 | 73 | scaled_test_series = scaler.transform(test_series.values.reshape(-1,1)).reshape(-1) 74 | scaled_test_series_ts = TimeSeries.from_times_and_values(test_series.index, scaled_test_series) 75 | scaled_test_ts_list.append(scaled_test_series_ts) 76 | 77 | scaled_combined_series = scaler.transform(pd.concat([train_series,test_series]).values.reshape(-1,1)).reshape(-1) 78 | scaled_combined_series_list.append(scaled_combined_series) 79 | 80 | print('************ Fitting model... ************') 81 | if n_train is not None: 82 | model.fit(scaled_train_ts_list[:n_train], epochs=epochs) 83 | else: 84 | model.fit(scaled_train_ts_list, epochs=epochs) 85 | 86 | rescaled_predictions_list = [] 87 | BPD_list = [] 88 | samples_list = [] 89 | samples = None 90 | median = None 91 | 92 | with torch.no_grad(): 93 | predictions = None 94 | if num_samples > 0: 95 | print('************ Predicting... 
************') 96 | predictions = model.predict(n=test_len, series=scaled_train_ts_list, num_samples=num_samples) 97 | for i in range(len(predictions)): 98 | prediction = predictions[i].data_array()[:,0,:].T.values 99 | rescaled_prediction = scaler.inverse_transform(prediction.reshape(-1,1)).reshape(num_samples,-1) 100 | samples = pd.DataFrame(rescaled_prediction, columns=test[i].index) 101 | rescaled_predictions_list.append(rescaled_prediction) 102 | samples_list.append(samples) 103 | samples = samples_list if len(samples_list)>1 else samples_list[0] 104 | median = [samples.median(axis=0) for samples in samples_list] if len(samples_list)>1 else samples_list[0].median(axis=0) 105 | print('************ Getting likelihood... ************') 106 | for i in range(len(scaled_combined_series_list)): 107 | BPD = get_chunked_AR_likelihoods(model,scaled_combined_series_list[i],len(train[i]),in_len,out_len,scaler) 108 | BPD_list.append(BPD) 109 | 110 | out_dict = { 111 | 'NLL/D': np.mean(BPD_list), 112 | 'samples': samples, 113 | 'median': median, 114 | 'info': {'Method':str(modeltype), 'epochs':epochs, 'out_len':out_len} 115 | } 116 | 117 | return out_dict 118 | 119 | def get_chunked_AR_likelihoods(model,scaled_series,trainsize,in_len,out_len,scaler): 120 | teacher_forced_inputs = torch.from_numpy(scaled_series[trainsize-in_len:][None,:,None]) 121 | testsize = len(scaled_series)-trainsize 122 | n = 0 123 | nll_sum = 0 124 | while n < testsize: 125 | inp = teacher_forced_inputs[:,n:n+in_len] 126 | elems_left = min(out_len, testsize-n) 127 | params = model.model((inp,None)) 128 | likelihood_params = params[:,-out_len:][:,:elems_left] 129 | likelihood_params2 = model.likelihood._params_from_output(likelihood_params) 130 | target = teacher_forced_inputs[:,in_len+n:in_len+n+elems_left] 131 | nll_sum += model.likelihood._nllloss(likelihood_params2,target).detach().numpy()*elems_left 132 | n += elems_left 133 | assert n == testsize 134 | nll_per_dimension = nll_sum/n 135 | nll_per_dimension -= np.log(scaler.scale_)#np.log(scaler._fitted_params[0].scale_) 136 | return nll_per_dimension.item() 137 | 138 | #from statsmodels.tsa.arima.model import ARIMA 139 | from statsmodels.tsa.arima.model import ARIMA as staARIMA 140 | import types 141 | 142 | def _new_arima_fit(self, series, future_covariates = None): 143 | super(darts.models.ARIMA,self)._fit(series, future_covariates) 144 | 145 | self._assert_univariate(series) 146 | 147 | # storing to restore the statsmodels model results object 148 | self.training_historic_future_covariates = future_covariates 149 | 150 | m = staARIMA( 151 | series.values(copy=False), 152 | exog=future_covariates.values(copy=False) if future_covariates else None, 153 | order=self.order, 154 | seasonal_order=self.seasonal_order, 155 | trend=self.trend, 156 | #initialization='approximate_diffuse', 157 | ) 158 | self.model = m.fit() 159 | 160 | return self 161 | 162 | def get_arima_predictions_data(train, test, p=12, d=1, q=0, num_samples=100, **kwargs): 163 | num_samples = max(num_samples, 1) 164 | if not isinstance(train, list): 165 | # assume single train/test case 166 | train = [train] 167 | test = [test] 168 | for i in range(len(train)): 169 | if not isinstance(train[i], pd.Series): 170 | train[i] = pd.Series(train[i], index = pd.RangeIndex(len(train[i]))) 171 | test[i] = pd.Series(test[i], index = pd.RangeIndex(len(train[i]),len(test[i])+len(train[i]))) 172 | 173 | test_len = len(test[0]) 174 | assert all(len(t)==test_len for t in test), f'All test series must have same length, got 
{[len(t) for t in test]}' 175 | 176 | model = darts.models.ARIMA(p=p, d=d, q=q) 177 | 178 | scaled_train_ts_list = [] 179 | scaled_test_ts_list = [] 180 | scaled_combined_series_list = [] 181 | scalers = [] 182 | 183 | 184 | # Iterate over each series in the train list 185 | for train_series, test_series in zip(train,test): 186 | # for ARIMA we scale each series individually 187 | scaler = MinMaxScaler() 188 | combined_series = pd.concat([train_series,test_series]) 189 | scaler.fit(combined_series.values.reshape(-1,1)) 190 | scalers.append(scaler) 191 | scaled_train_series = scaler.transform(train_series.values.reshape(-1,1)).reshape(-1) 192 | scaled_train_series_ts = TimeSeries.from_times_and_values(train_series.index, scaled_train_series) 193 | scaled_train_ts_list.append(scaled_train_series_ts) 194 | 195 | scaled_test_series = scaler.transform(test_series.values.reshape(-1,1)).reshape(-1) 196 | scaled_test_series_ts = TimeSeries.from_times_and_values(test_series.index, scaled_test_series) 197 | scaled_test_ts_list.append(scaled_test_series_ts) 198 | 199 | scaled_combined_series = scaler.transform(pd.concat([train_series,test_series]).values.reshape(-1,1)).reshape(-1) 200 | scaled_combined_series_list.append(scaled_combined_series) 201 | 202 | 203 | rescaled_predictions_list = [] 204 | nll_all_list = [] 205 | samples_list = [] 206 | 207 | for i in range(len(scaled_train_ts_list)): 208 | try: 209 | model.fit(scaled_train_ts_list[i]) 210 | prediction = model.predict(len(test[i]), num_samples=num_samples).data_array()[:,0,:].T.values 211 | scaler = scalers[i] 212 | rescaled_prediction = scaler.inverse_transform(prediction.reshape(-1,1)).reshape(num_samples,-1) 213 | fit_model = model.model.model.fit() 214 | fit_params = fit_model.conf_int().mean(1) 215 | all_model = staARIMA( 216 | scaled_combined_series_list[i], 217 | exog=None, 218 | order=model.order, 219 | seasonal_order=model.seasonal_order, 220 | trend=model.trend, 221 | ) 222 | nll_all = -all_model.loglikeobs(fit_params) 223 | nll_all = nll_all[len(train[i]):].sum()/len(test[i]) 224 | nll_all -= np.log(scaler.scale_) 225 | nll_all = nll_all.item() 226 | except np.linalg.LinAlgError: 227 | rescaled_prediction = np.zeros((num_samples,len(test[i]))) 228 | # output nan 229 | nll_all = np.nan 230 | 231 | samples = pd.DataFrame(rescaled_prediction, columns=test[i].index) 232 | 233 | rescaled_predictions_list.append(rescaled_prediction) 234 | nll_all_list.append(nll_all) 235 | samples_list.append(samples) 236 | 237 | out_dict = { 238 | 'NLL/D': np.mean(nll_all_list), 239 | 'samples': samples_list if len(samples_list)>1 else samples_list[0], 240 | 'median': [samples.median(axis=0) for samples in samples_list] if len(samples_list)>1 else samples_list[0].median(axis=0), 241 | 'info': {'Method':'ARIMA', 'p':p, 'd':d} 242 | } 243 | 244 | return out_dict 245 | 246 | -------------------------------------------------------------------------------- /utils_paragraph.py: -------------------------------------------------------------------------------- 1 | import os 2 | import openai 3 | import json 4 | with open('config.json', 'r', encoding='utf-8') as f: 5 | config = json.load(f) 6 | 7 | openai.api_key = config['OPENAI_API_KEY'] 8 | openai.api_base = config['OPENAI_API_BASE'] 9 | 10 | import pandas as pd 11 | from data1.small_context import get_datasets 12 | import re 13 | # from llama_utils import llama_api_qa 14 | from data1.serialize import SerializerSettings 15 | 16 | # Add situation description by each dataset 17 | 18 | model_select = 
'gpt-4-1106-preview'
19 | gpt3_hypers = dict(
20 | temp=0.7,
21 | alpha=0.95,
22 | beta=0.3,
23 | basic=False,
24 | settings=SerializerSettings(base=10, prec=3, signed=True, half_bin_correction=True)
25 | )
26 | 
27 | gpt4_hypers = dict(
28 | alpha=0.3,
29 | basic=True,
30 | temp=0.5,
31 | top_p=0.5,
32 | settings=SerializerSettings(base=10, prec=3, signed=True, time_sep=', ', bit_sep='', minus_sign='-')
33 | )
34 | 
35 | hyper_gpt35 = f"Set hyperparameters to: temp={gpt3_hypers['temp']}, " \
36 | f"alpha={gpt3_hypers['alpha']}, beta={gpt3_hypers['beta']}, " \
37 | f"basic={gpt3_hypers['basic']}, settings={gpt3_hypers['settings']}"
38 | 
39 | hyper_gpt4 = f"Set hyperparameters to: alpha={gpt4_hypers['alpha']}, " \
40 | f"temp={gpt4_hypers['temp']}, top_p={gpt4_hypers['top_p']}, " \
41 | f"basic={gpt4_hypers['basic']}, settings={gpt4_hypers['settings']}"
42 | def paraphrase_initial(data_name):
43 | if data_name == 'AirPassengersDataset':
44 | desp = "This is a series of monthly passenger numbers for international flights, " \
45 | "where each value is in thousands of passengers for that month. "
46 | elif data_name == 'AusBeerDataset':
47 | desp = "This is a quarterly series of beer production, where each value is " \
48 | "the kilolitres of beer produced in that quarter. "
49 | elif data_name == 'GasRateCO2Dataset':
50 | desp = "This is a time series dataset describing monthly carbon dioxide emissions. "
51 | elif data_name == 'MonthlyMilkDataset':
52 | desp = "This is a time-series dataset describing monthly milk production, " \
53 | "where each number is the average number of tons of milk produced per cow during the month. "
54 | elif data_name == 'SunspotsDataset':
55 | desp = "This is a dataset that records the number of sunspots in each month, " \
56 | "where each value is the number of sunspots in that month. "
57 | elif data_name == 'WineDataset':
58 | desp = "This is a dataset of monthly wine production in Australia, " \
59 | "where each figure is the number of wine bottles produced in that month. "
60 | elif data_name == 'WoolyDataset':
61 | desp = "This is Australian yarn production for each quarter, " \
62 | "where each value is how many tons of yarn were produced in that quarter. "
63 | elif data_name == 'HeartRateDataset':
64 | desp = "The series contains 1800 uniformly spaced instantaneous " \
65 | "heart rate measurements from a single subject. "
66 | else: raise ValueError(f"Unknown dataset: {data_name}")  # guard against returning an unbound desp
67 | return desp
68 | 
69 | 
70 | # Transfer Sequences to Natural Language.
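# Illustrative round trip (hypothetical values, not from the repo): with
# seq = pd.Series([112.0, 118.0, 118.0]) and desp = paraphrase_initial('AirPassengersDataset'),
# paraphrase_seq2lan (below) returns the description followed by
# "from 112.0 increasing to 118.0, it remains flat from 118.0 to 118.0, ",
# and recover_lan2seq maps such text back to a pd.Series.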
71 | # seq: pandas Series of values, desp: string description of the dataset
72 | def paraphrase_seq2lan(seq, desp):
73 | results = ''
74 | # The values of the sequence are read one by one and a description of each change is output
75 | for i in range(len(seq) - 1):
76 | t1 = seq.iloc[i] # select the element at position i
77 | t2 = seq.iloc[i + 1] # select the next element
78 | result = describe_change(t1, t2)
79 | results += result
80 | lan = desp + results
81 | 
82 | return lan
83 | 
84 | 
85 | def describe_change(t1, t2):
86 | if t2 > t1:
87 | return f"from {t1} increasing to {t2}, "
88 | elif t2 < t1:
89 | return f"from {t1} decreasing to {t2}, "
90 | else:
91 | return f"it remains flat from {t1} to {t2}, "
92 | 
93 | 
94 | # Recover the sequence from a natural language description
95 | def recover_lan2seq(input_string):
96 | # Step 1: cut off the leading description
97 | dot_index = input_string.find('.')
98 | cleaned_string = input_string[dot_index + 1:].strip() if dot_index != -1 else input_string.strip()
99 | 
100 | # Step 2: extract the numbers
101 | numbers = re.findall(r'(\d+\.\d+)', cleaned_string)
102 | # Convert to a list of floats
103 | float_numbers = [float(num) for num in numbers]
104 | 
105 | # Step 3: drop the duplicated numbers (each value appears twice in the description)
106 | filtered_numbers = [float_numbers[i] for i in range(len(float_numbers)) if i % 2 == 0]
107 | # add the last one
108 | filtered_numbers.append(float_numbers[-1])
109 | # recover a pandas Series
110 | result_series = pd.Series(filtered_numbers)
111 | 
112 | return result_series
113 | 
114 | 
115 | def recover_lan2seq_llm(input_string):
116 | # Step 1: extract the numbers
117 | numbers = re.findall(r'(\d+\.\d+)', input_string)
118 | # Convert to a list of floats
119 | float_numbers = [float(num) for num in numbers]
120 | 
121 | # Step 2: drop the duplicated numbers (each value appears twice in the description)
122 | filtered_numbers = [float_numbers[i] for i in range(len(float_numbers)) if i % 2 == 0]
123 | # add the last one
124 | filtered_numbers.append(float_numbers[-1])
125 | # recover a pandas Series
126 | result_series = pd.Series(filtered_numbers)
127 | 
128 | return result_series
129 | 
130 | 
131 | def paraphrase_nlp(datasets_list):
132 | # Train_lan: the paraphrased train sequence
133 | # Test_lan: the paraphrased test sequence
134 | # seq_test: pandas Series recovered from the paraphrased test sequence
135 | datasets = get_datasets()
136 | for dataset_name in datasets_list:
137 | desp = paraphrase_initial(dataset_name)
138 | data = datasets[dataset_name]
139 | train, test = data
140 | print("Train len:", train.shape)
141 | print("Test len:", test.shape)
142 | Train_lan = paraphrase_seq2lan(train, desp)
143 | Test_lan = paraphrase_seq2lan(test, desp)
144 | seq_test = recover_lan2seq(Test_lan)
145 | print("seq pred len:", seq_test.shape)
146 | if test.shape != seq_test.shape:
147 | print("Warning: values were lost in the paraphrase round trip!")
148 | 
149 | return Train_lan, Test_lan, seq_test # NOTE: only the last dataset's results are returned
150 | 
151 | 
152 | def paraphrase_llm(datasets_list):
153 | prompt = " Analyze this time series and rewrite it" \
154 | " as a trend-by-trend representation of discrete values. Only numerical " \
155 | "changes are described, not date changes. For example, use a template like {from 1.0 increasing to 2.0, " \
156 | "from 2.0 decreasing to 0.5,}. Be careful not to lose any sequence value. 
" 157 | datasets = get_datasets() 158 | 159 | for dataset_name in datasets_list: 160 | desp = paraphrase_initial(dataset_name) 161 | data = datasets[dataset_name] 162 | train, test = data 163 | content_train = "You are a useful assistant," + desp + str(train) 164 | content_test = "You are a useful assistant," + desp + str(test) 165 | response = openai.ChatCompletion.create( 166 | model=model_select, 167 | response_format={"type": "text"}, 168 | messages=[ 169 | {"role": "system", "content": prompt}, 170 | {"role": "user", "content": content_train} 171 | ] 172 | ) 173 | Train_lan = response.choices[0].message.content 174 | response = openai.ChatCompletion.create( 175 | model=model_select, 176 | response_format={"type": "text"}, 177 | messages=[ 178 | {"role": "system", "content": prompt}, 179 | {"role": "user", "content": content_test} 180 | ] 181 | ) 182 | Test_lan = response.choices[0].message.content 183 | print("Test_lan:", Test_lan) 184 | seq_test = recover_lan2seq_llm(Test_lan) 185 | if test.shape != seq_test.shape: 186 | print("the process error!") 187 | print("seq_test.shape:", seq_test.shape) 188 | print("test.shape:", test.shape) 189 | 190 | return Train_lan, Test_lan, seq_test 191 | 192 | def paraphrasing_predict_llm(desp, train_lan, steps, model_name): 193 | if model_name == 'gpt-3.5-turbo-0125': 194 | hyper_parameters_message = hyper_gpt35 195 | else: 196 | hyper_parameters_message = hyper_gpt4 197 | prompt = "You are a helpful assistant that performs time series predictions. " \ 198 | "The user will provide a sequence and you will predict the remaining sequence." \ 199 | "The sequence is represented by decimal strings separated by commas. " \ 200 | "Please continue the following sequence without producing any additional text. " \ 201 | "Do not say anything like 'the next terms in the sequence are', just return the numbers." 202 | # prompt_add = f"Please predict ahead in {steps} steps, one step means (from 1.0 increasing to 2.0,) or" \ 203 | # "(from 2.0 decreasing to 0.5,), The final output follows exactly steps. Sequence:\n" 204 | prompt_add = (f"Predict the next {steps} steps, where each step follows the format (starting from 1.0 and increasing to 2.0) or (starting from 2.0 and decreasing to 0.5)."\ 205 | " The final output should precisely follow the specified number of steps. Provide a sequence:\n") 206 | 207 | content_train = prompt + desp + prompt_add + train_lan 208 | response = openai.ChatCompletion.create( 209 | model=model_name, 210 | response_format={"type": "text"}, 211 | messages=[ 212 | {"role": "system", "content": hyper_parameters_message}, 213 | {"role": "system", "content": prompt}, 214 | {"role": "user", "content": content_train} 215 | ] 216 | ) 217 | Test_lan = response.choices[0].message.content 218 | # print("Test_lan:", Test_lan) 219 | seq_test = recover_lan2seq_llm(Test_lan) 220 | 221 | return seq_test 222 | 223 | 224 | def paraphrasing_predict_llama(desp, train_lan, steps, model_name): 225 | # prompt = "You are a helpful assistant that performs time series predictions. " \ 226 | # "The user will provide a sequence and you will predict the remaining sequence." \ 227 | # "The sequence is represented by decimal strings separated by commas. " \ 228 | # "Please continue the following sequence without producing any additional text. " \ 229 | # "Do not say anything like 'the next terms in the sequence are', just return the numbers." 
230 | prompt_add = f"Please predict {steps} steps ahead; one step looks like (from 1.0 increasing to 2.0,) or " \
231 | "(from 2.0 decreasing to 0.5,). The final output must contain exactly that number of steps. Sequence:\n"
232 | content_train = desp + prompt_add + train_lan
233 | response = llama_api_qa(model_name, content_train)
234 | seq_test = recover_lan2seq_llm(response)
235 | 
236 | return seq_test
237 | 
238 | 
239 | # test main
240 | if __name__ == '__main__':
241 | # initial
242 | datasets_list = [
243 | 'AirPassengersDataset',
244 | 'AusBeerDataset',
245 | 'GasRateCO2Dataset',
246 | 'MonthlyMilkDataset',
247 | 'SunspotsDataset',
248 | 'WineDataset',
249 | 'WoolyDataset',
250 | 'HeartRateDataset',
251 | ]
252 | 
253 | # traditional
254 | paraphrase_nlp(datasets_list)
255 | 
256 | # LLM
257 | paraphrase_llm(datasets_list)
258 | 
-------------------------------------------------------------------------------- /models/llmtime.py: --------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from data1.serialize import serialize_arr, deserialize_str, SerializerSettings
3 | from concurrent.futures import ThreadPoolExecutor
4 | import numpy as np
5 | import pandas as pd
6 | from dataclasses import dataclass
7 | from models.llms import completion_fns, nll_fns, tokenization_fns, context_lengths
8 | 
9 | STEP_MULTIPLIER = 1.2
10 | 
11 | @dataclass
12 | class Scaler:
13 | """
14 | Represents a data scaler with transformation and inverse transformation functions.
15 | 
16 | Attributes:
17 | transform (callable): Function to apply transformation.
18 | inv_transform (callable): Function to apply inverse transformation.
19 | """
20 | transform: callable = lambda x: x
21 | inv_transform: callable = lambda x: x
22 | 
23 | def get_scaler(history, alpha=0.95, beta=0.3, basic=False):
24 | """
25 | Generate a Scaler object based on given history data.
26 | 
27 | Args:
28 | history (array-like): Data to derive scaling from.
29 | alpha (float, optional): Quantile for scaling. Defaults to .95.
30 | 
31 | 
32 | beta (float, optional): Shift parameter. Defaults to .3.
33 | basic (bool, optional): If True, no shift is applied, and scaling by values below 0.01 is avoided. Defaults to False.
34 | 
35 | Returns:
36 | Scaler: Configured scaler object.
37 | """
38 | history = history[~np.isnan(history)]
39 | if basic:
40 | q = np.maximum(np.quantile(np.abs(history), alpha),.01)
41 | def transform(x):
42 | return x / q
43 | def inv_transform(x):
44 | return x * q
45 | else:
46 | min_ = np.min(history) - beta*(np.max(history)-np.min(history))
47 | q = np.quantile(history-min_, alpha)
48 | if q == 0:
49 | q = 1
50 | def transform(x):
51 | return (x - min_) / q
52 | def inv_transform(x):
53 | return x * q + min_
54 | return Scaler(transform=transform, inv_transform=inv_transform)
55 | 
56 | def truncate_input(input_arr, input_str, settings, model, steps):
57 | """
58 | Truncate inputs to the maximum context length for a given model.
59 | 
60 | Args:
61 | input_arr (array-like): Input time series.
62 | input_str (str): serialized input time series.
63 | settings (SerializerSettings): Serialization settings.
64 | model (str): Name of the LLM model to use.
65 | steps (int): Number of steps to predict.
66 | Returns:
67 | tuple: Tuple containing:
68 | - input_arr (array-like): Truncated input time series.
69 | - input_str (str): Truncated serialized input time series.
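
    Example (illustrative numbers): if the serialized history is 3000 tokens but the
    model's context length is 2048, whole time steps are dropped from the front of the
    series until the input tokens plus the estimated output tokens
    (avg tokens per step * steps * STEP_MULTIPLIER) fit within the context length.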
70 | """ 71 | # 此处gemini采用数字中间无空格的方式 (此处暂时如此设定) 72 | if model in tokenization_fns and model in context_lengths: 73 | tokenization_fn = tokenization_fns[model] 74 | context_length = context_lengths[model] 75 | input_str_chuncks = input_str.split(settings.time_sep) 76 | for i in range(len(input_str_chuncks) - 1): 77 | truncated_input_str = settings.time_sep.join(input_str_chuncks[i:]) 78 | # add separator if not already present 79 | if not truncated_input_str.endswith(settings.time_sep): 80 | truncated_input_str += settings.time_sep 81 | input_tokens = tokenization_fn(truncated_input_str) 82 | num_input_tokens = len(input_tokens) 83 | avg_token_length = num_input_tokens / (len(input_str_chuncks) - i) 84 | num_output_tokens = avg_token_length * steps * STEP_MULTIPLIER 85 | if num_input_tokens + num_output_tokens <= context_length: 86 | truncated_input_arr = input_arr[i:] 87 | break 88 | if i > 0: 89 | print(f'Warning: Truncated input from {len(input_arr)} to {len(truncated_input_arr)}') 90 | return truncated_input_arr, truncated_input_str 91 | else: 92 | return input_arr, input_str 93 | 94 | def handle_prediction(pred, expected_length, strict=False): 95 | """ 96 | Process the output from LLM after deserialization, which may be too long or too short, or None if deserialization failed on the first prediction step. 97 | 98 | Args: 99 | pred (array-like or None): The predicted values. None indicates deserialization failed. 100 | expected_length (int): Expected length of the prediction. 101 | strict (bool, optional): If True, returns None for invalid predictions. Defaults to False. 102 | 103 | Returns: 104 | array-like: Processed prediction. 105 | """ 106 | if pred is None: 107 | return None 108 | else: 109 | if len(pred) < expected_length: 110 | if strict: 111 | print(f'Warning: Prediction too short {len(pred)} < {expected_length}, returning None') 112 | return None 113 | else: 114 | print(f'Warning: Prediction too short {len(pred)} < {expected_length}, padded with last value') 115 | return np.concatenate([pred, np.full(expected_length - len(pred), pred[-1])]) 116 | else: 117 | return pred[:expected_length] 118 | 119 | def generate_predictions( 120 | completion_fn, 121 | input_strs, 122 | steps, 123 | settings: SerializerSettings, 124 | scalers: None, 125 | num_samples=1, 126 | temp=0.7, 127 | parallel=True, 128 | strict_handling=False, 129 | max_concurrent=10, 130 | whether_blanket=True, 131 | genai_key=None, 132 | **kwargs 133 | ): 134 | """ 135 | Generate and process text completions from a language model for input time series. 136 | 137 | Args: 138 | completion_fn (callable): Function to obtain text completions from the LLM. 139 | input_strs (list of array-like): List of input time series. 140 | steps (int): Number of steps to predict. 141 | settings (SerializerSettings): Settings for serialization. 142 | scalers (list of Scaler, optional): List of Scaler objects. Defaults to None, meaning no scaling is applied. 143 | num_samples (int, optional): Number of samples to return. Defaults to 1. 144 | temp (float, optional): Temperature for sampling. Defaults to 0.7. 145 | parallel (bool, optional): If True, run completions in parallel. Defaults to True. 146 | strict_handling (bool, optional): If True, return None for predictions that don't have exactly the right format or expected length. Defaults to False. 147 | max_concurrent (int, optional): Maximum number of concurrent completions. Defaults to 50. 148 | **kwargs: Additional keyword arguments. 
149 | 
150 | Returns:
151 | tuple: Tuple containing:
152 | - preds (list of lists): Numerical predictions.
153 | - completions_list (list of lists): Raw text completions.
154 | - input_strs (list of str): Serialized input strings.
155 | """
156 | if scalers is None: scalers = [Scaler() for _ in input_strs] # default to identity scaling, matching the documented default
157 | completions_list = []
158 | complete = lambda x: completion_fn(input_str=x, steps=steps*STEP_MULTIPLIER, settings=settings, num_samples=num_samples, temp=temp, whether_blanket=whether_blanket, genai_key=genai_key)
159 | # position marker: the completion calls happen below
160 | if parallel and len(input_strs) > 1:
161 | print('Running completions in parallel for each input')
162 | with ThreadPoolExecutor(min(max_concurrent, len(input_strs))) as p:
163 | completions_list = list(tqdm(p.map(complete, input_strs), total=len(input_strs)))
164 | else:
165 | completions_list = [complete(input_str) for input_str in tqdm(input_strs)]
166 | def completion_to_pred(completion, inv_transform):
167 | pred = handle_prediction(deserialize_str(completion, settings, ignore_last=False, steps=steps), expected_length=steps, strict=strict_handling)
168 | if pred is not None:
169 | return inv_transform(pred)
170 | else:
171 | return None
172 | preds = [[completion_to_pred(completion, scaler.inv_transform) for completion in completions] for completions, scaler in zip(completions_list, scalers)]
173 | return preds, completions_list, input_strs
174 | 
175 | def get_llmtime_predictions_data(train, test, model, settings, num_samples=10, temp=0.7, alpha=0.95, beta=0.3, basic=False, parallel=True, whether_blanket=True, genai_key=None, **kwargs):
176 | """
177 | Obtain forecasts from an LLM based on training series (history) and evaluate likelihood on test series (true future).
178 | train and test can be either a single time series or a list of time series.
179 | 
180 | Args:
181 | train (array-like or list of array-like): Training time series data (history).
182 | test (array-like or list of array-like): Test time series data (true future).
183 | model (str): Name of the LLM model to use. Must have a corresponding entry in completion_fns.
184 | settings (SerializerSettings or dict): Serialization settings.
185 | num_samples (int, optional): Number of samples to return. Defaults to 10.
186 | temp (float, optional): Temperature for sampling. Defaults to 0.7.
187 | alpha (float, optional): Scaling parameter. Defaults to 0.95.
188 | beta (float, optional): Shift parameter. Defaults to 0.3.
189 | basic (bool, optional): If True, use the basic version of data scaling. Defaults to False.
190 | parallel (bool, optional): If True, run predictions in parallel. Defaults to True.
191 | **kwargs: Additional keyword arguments.
192 | 
193 | Returns:
194 | dict: Dictionary containing predictions, samples, median, NLL/D averaged over each series, and other related information.
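
    Example (illustrative sketch; 'gpt-3.5-turbo' is assumed to be registered in
    completion_fns, and train_series/test_series are placeholder pd.Series):
        settings = SerializerSettings(base=10, prec=3, signed=True)
        out = get_llmtime_predictions_data(train_series, test_series, model='gpt-3.5-turbo',
                                           settings=settings, num_samples=10, temp=0.7)
        out['median']   # per-step median forecast, indexed like test_series
        out['samples']  # DataFrame of num_samples sampled trajectories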
195 | """ 196 | # 这里不是直接预测,还必须在原代码里有阐述(有趣, 可以说是形成了体系) 197 | assert model in completion_fns, f'Invalid model {model}, must be one of {list(completion_fns.keys())}' 198 | completion_fn = completion_fns[model] 199 | # 这里对于使用的大模型需要放进去 200 | nll_fn = nll_fns[model] if model in nll_fns else None 201 | # 这里可以无视 (以gpt模型作为参考) 202 | 203 | if isinstance(settings, dict): 204 | settings = SerializerSettings(**settings) # 将其字典化 205 | if not isinstance(train, list): 206 | # Assume single train/test case 207 | train = [train] 208 | test = [test] 209 | 210 | for i in range(len(train)): 211 | if not isinstance(train[i], pd.Series): 212 | train[i] = pd.Series(train[i], index=pd.RangeIndex(len(train[i]))) 213 | test[i] = pd.Series(test[i], index=pd.RangeIndex(len(train[i]), len(test[i])+len(train[i]))) 214 | 215 | test_len = len(test[0]) 216 | assert all(len(t)==test_len for t in test), f'All test series must have same length, got {[len(t) for t in test]}' 217 | 218 | # Create a unique scaler for each series 219 | scalers = [get_scaler(train[i].values, alpha=alpha, beta=beta, basic=basic) for i in range(len(train))] 220 | # 此处是超参数寻优的点 221 | 222 | # transform input_arrs 223 | input_arrs = [train[i].values for i in range(len(train))] 224 | transformed_input_arrs = np.array([scaler.transform(input_array) for input_array, scaler in zip(input_arrs, scalers)]) 225 | # serialize input_arrs 226 | input_strs = [serialize_arr(scaled_input_arr, settings) for scaled_input_arr in transformed_input_arrs] 227 | # Truncate input_arrs to fit the maximum context length 228 | # 这里有truncate操作,可能和gpt不同 229 | input_arrs, input_strs = zip(*[truncate_input(input_array, input_str, settings, model, test_len) for input_array, input_str in zip(input_arrs, input_strs)]) 230 | # 此处 input_str 是tuple, 每个元素是一个数据集 231 | 232 | steps = test_len 233 | samples = None 234 | medians = None 235 | completions_list = None 236 | if num_samples > 0: 237 | preds, completions_list, input_strs = generate_predictions(completion_fn, input_strs, steps, settings, scalers, 238 | num_samples=num_samples, temp=temp, 239 | parallel=parallel, whether_blanket=whether_blanket, genai_key=genai_key, **kwargs) 240 | # 位置 241 | samples = [pd.DataFrame(preds[i], columns=test[i].index) for i in range(len(preds))] 242 | medians = [sample.median(axis=0) for sample in samples] 243 | samples = samples if len(samples) > 1 else samples[0] 244 | medians = medians if len(medians) > 1 else medians[0] 245 | out_dict = { 246 | 'samples': samples, 247 | 'median': medians, 248 | 'info': { 249 | 'Method': model, 250 | }, 251 | 'completions_list': completions_list, 252 | 'input_strs': input_strs, 253 | } 254 | # Compute NLL/D on the true test series conditioned on the (truncated) input series 255 | if nll_fn is not None: 256 | BPDs = [nll_fn(input_arr=input_arrs[i], target_arr=test[i].values, settings=settings, transform=scalers[i].transform, count_seps=True, temp=temp) for i in range(len(train))] 257 | out_dict['NLL/D'] = np.mean(BPDs) 258 | return out_dict 259 | -------------------------------------------------------------------------------- /models/promptcast.py: -------------------------------------------------------------------------------- 1 | from data1.metrics import Evaluator 2 | from tqdm import tqdm 3 | from multiprocess import Pool 4 | from functools import partial 5 | import tiktoken 6 | from functools import partial 7 | from data1.serialize import serialize_arr, deserialize_str, SerializerSettings 8 | import openai 9 | from concurrent.futures import ThreadPoolExecutor 10 | 
import numpy as np 11 | import matplotlib.pyplot as plt 12 | from data1.metrics import nll 13 | import pandas as pd 14 | from dataclasses import dataclass 15 | 16 | @dataclass 17 | class Scaler: 18 | transform: callable = lambda x: x 19 | inv_transform: callable = lambda x: x 20 | 21 | def get_scaler(history, alpha=.9, beta=.3,basic=False): 22 | # shift (min - beta*(max-min)) to 0 23 | # then scale alpha quantile to 1 24 | # alpha = -1 means no scaling 25 | history = history[~np.isnan(history)] 26 | min_ = np.min(history) - beta*(np.max(history)-np.min(history)) 27 | if basic: 28 | q = np.maximum(np.quantile(np.abs(history), alpha),.01) 29 | # scale so that alpha fraction of values are within [0, 1] 30 | def transform(x): 31 | return x / q 32 | def inv_transform(x): 33 | return x * q 34 | return Scaler(transform=transform, inv_transform=inv_transform) 35 | if alpha == -1: 36 | q = 1 37 | else: 38 | q = np.quantile(history-min_, alpha) 39 | if q == 0: 40 | q = 1 41 | # scale so that alpha fraction of values are within [0, 1] 42 | def transform(x): 43 | return (x - min_) / q 44 | def inv_transform(x): 45 | return x * q + min_ 46 | return Scaler(transform=transform, inv_transform=inv_transform) 47 | 48 | def get_token_ids(tokens, model, input_string): 49 | encoding = tiktoken.encoding_for_model(model) 50 | ids = [] 51 | for t in tokens: 52 | id = encoding.encode(t) 53 | if len(id) != 1: 54 | for i in id: 55 | ids.append(i) 56 | #raise ValueError(f'{t} is not a single token') 57 | else: 58 | ids.append(id[0]) 59 | return ids 60 | 61 | def get_avg_tokens_per_step(input_str,settings): 62 | input_tokens = sum([1 + len(x) / 2 for x in input_str.split(settings.time_sep)]) # add 1 for the comma, divide by 2 for the space 63 | input_steps = len(input_str.split(settings.time_sep)) 64 | tokens_per_step = input_tokens / input_steps 65 | return tokens_per_step 66 | 67 | def truncate(train, test, scaler, model, settings): 68 | tokens_perstep = get_avg_tokens_per_step( 69 | serialize_arr( 70 | scaler.transform(pd.concat([train,test]).values), 71 | settings 72 | ), 73 | settings 74 | ) 75 | if model == 'gpt-4': 76 | max_tokens=6000 77 | elif model == 'gpt-3.5-turbo': 78 | max_tokens = 4000 79 | else: 80 | max_tokens = 2000 81 | 82 | # 1.3 accounts for overhead in sampling 83 | if 1.35*tokens_perstep*(len(train)+len(test)) > max_tokens: 84 | total_timestep_budget = int(max_tokens/tokens_perstep) 85 | full_train_len = len(train) 86 | for num_try in range(10): 87 | sub_train = train.iloc[-(total_timestep_budget-len(test)):] 88 | if 1.35*tokens_perstep*(len(sub_train)+len(test)) <= max_tokens: 89 | train = sub_train 90 | print(f"Truncated train to {full_train_len} --> {len(train)} timesteps") 91 | break 92 | total_timestep_budget = int(0.8 * total_timestep_budget) 93 | else: 94 | raise ValueError(f"After truncation, dataset is still too large for GPT-3, 1.3 * {tokens_perstep} * ({len(sub_train)} + {len(test)}) = {1.3*tokens_perstep*(len(sub_train)+len(test))} > {max_tokens}") 95 | return train 96 | 97 | def sample_completions(model, input_str, steps, settings, num_samples, temp, logit_bias,**kwargs): 98 | ''' Sample completions from GPT-3 99 | Args: 100 | input_str: input sequence as a string 101 | steps: number of steps to predict 102 | precision: number of bits to use for encoding 103 | num_samples: number of samples to return 104 | temp: temperature for sampling 105 | prompt: additional prompt before the input string 106 | model: name of GPT-3 model to use 107 | Returns: 108 | list of completion strings 109 
| ''' 110 | # estimate avg number of tokens per step 111 | tokens_per_step = get_avg_tokens_per_step(input_str,settings) 112 | steps = int(steps * 1.3) # add some overhead to account for the fact that we don't know the exact number of tokens per step 113 | # if not logit_bias: 114 | # input_str = input_str[:-1] 115 | if model in ['gpt-3.5-turbo','gpt-4']: 116 | chatgpt_sys_message = "You are a helpful assistant that performs time series predictions. The user will provide a sequence and you will predict the remaining sequence. The sequence is represented by decimal strings separated by commas." 117 | extra_input = "Please continue the following sequence without producing any additional text. Do not say anything like 'the next terms in the sequence are', just return the numbers. Sequence:\n" 118 | response = openai.ChatCompletion.create( 119 | model=model, 120 | messages=[ 121 | {"role": "system", "content": chatgpt_sys_message}, 122 | {"role": "user", "content": extra_input+input_str+settings.time_sep} 123 | ], 124 | max_tokens=int(tokens_per_step*steps), 125 | temperature=temp, 126 | logit_bias=logit_bias, 127 | n=num_samples, 128 | **kwargs 129 | ) 130 | return [choice.message.content for choice in response.choices] 131 | else: 132 | response = openai.Completion.create( 133 | model=model, 134 | prompt=input_str, 135 | max_tokens=int(tokens_per_step*steps), 136 | temperature=temp, 137 | logit_bias=logit_bias, 138 | n=num_samples 139 | ) 140 | return [choice.text for choice in response.choices] 141 | 142 | def handle_prediction(input, pred, expected_length, strict=False): 143 | ''' Handle prediction with expected length of expected_length. 144 | Useful for handling predictions that can't be deserialized or are too short or too long. 145 | ''' 146 | if strict: 147 | # must be a valid array (not None) and have the enough entries 148 | if pred is None or len(pred) < expected_length: 149 | print('Found invalid prediction') 150 | return None 151 | else: 152 | return pred[:expected_length] 153 | else: 154 | if pred is None: 155 | print('Warning: prediction failed to be deserialized, replaced with last value') 156 | return np.full(expected_length, input[-1]) 157 | elif len(pred) < expected_length: 158 | print(f'Warning: Prediction too short {len(pred)} < {expected_length}, padded with last value') 159 | return np.concatenate([pred, np.full(expected_length - len(pred), pred[-1])]) 160 | elif len(pred) > expected_length: 161 | return pred[:expected_length] 162 | else: 163 | return pred 164 | 165 | def generate_predictions( 166 | model, 167 | inputs, 168 | steps, 169 | settings: SerializerSettings, 170 | scalers: None, 171 | num_samples=1, 172 | temp=0.3, 173 | prompts=None, 174 | post_prompts=None, 175 | parallel=True, 176 | return_input_strs=False, 177 | constrain_tokens=True, 178 | strict_handling=False, 179 | **kwargs, 180 | ): 181 | ''' Generate predictions from GPT-3 for a batch of inputs by calling sample_completions 182 | Args: 183 | inputs: np float array of shape (batch_size, history_len) 184 | steps: number of steps to predict 185 | precision: number of bits to use for encoding 186 | num_samples: number of samples to return 187 | temp: temperature for sampling 188 | prompt: None or a batch of additional prompts before the input string 189 | post_prompt: None or a batch of additional prompts after the input string (e.g. 
for promptcast) 190 | model: name of GPT-3 model to use 191 | Returns: 192 | np float array of shape (batch_size, num_samples, steps) 193 | ''' 194 | if prompts is None: 195 | prompts = [''] * len(inputs) 196 | if post_prompts is None: 197 | post_prompts = [''] * len(inputs) 198 | assert len(prompts) == len(inputs), f'Number of prompts must match number of inputs, got {len(prompts)} prompts and {len(inputs)} inputs' 199 | assert len(post_prompts) == len(inputs), f'Number of post prompts must match number of inputs, got {len(post_prompts)} post prompts and {len(inputs)} inputs' 200 | 201 | if scalers is None: 202 | scalers = [Scaler() for _ in inputs] 203 | else: 204 | assert len(scalers) == len(inputs), 'Number of scalers must match number of inputs' 205 | 206 | transformed_inputs = np.array([scaler.transform(input_array) for input_array, scaler in zip(inputs, scalers)]) 207 | input_strs = [serialize_arr(scaled_input_array, settings) for scaled_input_array in transformed_inputs] 208 | if post_prompts[0] != '': 209 | # removing last time separator for promptcast 210 | input_strs = [prompt + input_str.rstrip(settings.time_sep) + post_prompt for input_str, prompt, post_prompt in zip(input_strs, prompts, post_prompts)] 211 | else: 212 | input_strs = [prompt + input_str for input_str, prompt in zip(input_strs, prompts)] 213 | allowed_tokens = [settings.bit_sep + str(i) for i in range(settings.base)] 214 | allowed_tokens += [settings.time_sep, settings.plus_sign, settings.minus_sign] 215 | allowed_tokens = [t for t in allowed_tokens if len(t) > 0] # remove empty tokens like an implicit plus sign 216 | logit_bias = {} 217 | if (model not in ['gpt-3.5-turbo','gpt-4']) and constrain_tokens: # logit bias not supported for chat models 218 | logit_bias = {id: 30 for id in get_token_ids(allowed_tokens, model,input_strs[0])} 219 | if not constrain_tokens: 220 | logit_bias = {id: 5 for id in get_token_ids(allowed_tokens, model,input_strs[0])} 221 | 222 | completions_list = [] 223 | complete = lambda x: sample_completions(model, x, steps, settings, num_samples, temp, logit_bias,**kwargs) 224 | if parallel and len(inputs) > 1: 225 | with ThreadPoolExecutor(len(inputs)) as p: 226 | completions_list = list(tqdm(p.map(complete, input_strs), total=len(inputs))) 227 | else: 228 | completions_list = [complete(input_str) for input_str in tqdm(input_strs)] 229 | # print(completions_list) 230 | def completion_to_pred(completion, transformed_input, inv_transform): 231 | pred = handle_prediction(transformed_input, deserialize_str(completion, settings, ignore_last=False, steps=steps), expected_length=steps, strict=strict_handling) 232 | if pred is not None: 233 | return inv_transform(pred) 234 | else: 235 | return None 236 | preds = [[completion_to_pred(completion, transformed_input, scaler.inv_transform) for completion in completions] for completions, transformed_input, scaler in zip(completions_list, transformed_inputs, scalers)] 237 | if return_input_strs: 238 | return preds, completions_list, input_strs 239 | return preds, completions_list 240 | 241 | def get_promptcast_predictions_data(train, test, model, settings, num_samples=10, temp=0.8, dataset_name='dataset', **kwargs): 242 | if isinstance(settings, dict): 243 | settings = SerializerSettings(**settings) 244 | if not isinstance(train, list): 245 | # Assume single train/test case 246 | train = [train] 247 | test = [test] 248 | 249 | for i in range(len(train)): 250 | if not isinstance(train[i], pd.Series): 251 | train[i] = pd.Series(train[i], 
index=pd.RangeIndex(len(train[i])))
252 | test[i] = pd.Series(test[i], index=pd.RangeIndex(len(train[i]), len(test[i])+len(train[i])))
253 | 
254 | test_len = len(test[0])
255 | assert all(len(t)==test_len for t in test), f'All test series must have same length, got {[len(t) for t in test]}'
256 | 
257 | # Identity scalers
258 | scalers = [Scaler() for _ in range(len(train))]
259 | 
260 | for i in range(len(train)):
261 | train[i] = truncate(train[i], test[i], scalers[i], model, settings)
262 | 
263 | prompt = f'The values in the {dataset_name} for the past {len(train[0])} time steps are '
264 | prompts = [prompt] * len(train)
265 | post_prompt = f'. What will the values for the next {len(test[0])} time steps be? The values for the next {len(test[0])} time steps will be '
266 | post_prompts = [post_prompt] * len(train)
267 | 
268 | # Create inputs for GPT model
269 | inputs = [train[i].values for i in range(len(train))]
270 | steps = test_len
271 | 
272 | samples = None
273 | medians = None
274 | completions_list = None
275 | input_strs = None
276 | if num_samples > 0:
277 | # Generate predictions
278 | preds, completions_list, input_strs = generate_predictions(model, inputs, steps, settings, scalers,
279 | num_samples=num_samples, temp=temp, prompts=prompts, post_prompts=post_prompts,
280 | parallel=True, return_input_strs=True, constrain_tokens=False, strict_handling=True, **kwargs)
281 | # skip bad samples
282 | samples = [pd.DataFrame(np.array([p for p in preds[i] if p is not None]), columns=test[i].index) for i in range(len(preds))]
283 | medians = [sample.median(axis=0) for sample in samples]
284 | samples = samples if len(samples) > 1 else samples[0]
285 | print('Got %d properly formatted samples' % len(samples))
286 | medians = medians if len(medians) > 1 else medians[0]
287 | out_dict = {
288 | 'samples': samples,
289 | 'median': medians,
290 | 'info': {
291 | 'Method': model,
292 | },
293 | 'completions_list': completions_list,
294 | 'input_strs': input_strs,
295 | }
296 | 
297 | out_dict['NLL/D'] = None
298 | 
299 | return out_dict
-------------------------------------------------------------------------------- /utils_others.py: --------------------------------------------------------------------------------
1 | import copy
2 | import time
3 | import matplotlib.pyplot as plt
4 | import pandas as pd
5 | from pandas.plotting import register_matplotlib_converters
6 | register_matplotlib_converters()
7 | import numpy as np
8 | import matplotlib.font_manager as fm
9 | import seaborn as sns
10 | import os
11 | 
12 | os.environ['OMP_NUM_THREADS'] = '4'
13 | 
14 | from models.utils import grid_iter
15 | from models.promptcast import get_promptcast_predictions_data
16 | from models.darts import get_arima_predictions_data
17 | from models.validation_likelihood_tuning import get_autotuned_predictions_data
18 | from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
19 | from models import llmtime
20 | from models.llmtime import get_llmtime_predictions_data
21 | 
22 | import pathlib
23 | import textwrap
24 | import google.generativeai as genai
25 | 
26 | 
27 | 
28 | 
29 | from data1.serialize import SerializerSettings
30 | from sklearn import metrics
31 | from data1.small_context import get_datasets, get_memorization_datasets, get_dataset
32 | 
33 | 
34 | import warnings
35 | 
warnings.filterwarnings("ignore", category=DeprecationWarning)
36 | 
37 | 
38 | 
39 | def plot_preds_w_train_test(train, test, pred_dict, model_name, ds_name, dir='prediction_w_gemini/prediction_w_train+test', show_samples=False):
40 | """
41 | Plot predictions with confidence intervals. (Contains both the training and test sets.)
42 | 
43 | Parameters:
44 | train (pd.Series): Time series of training data.
45 | test (pd.Series): Time series of testing data (ground truth).
46 | pred_dict (dict): Dictionary containing predictions and other metrics.
47 | model_name (str): Name of the predictive model.
48 | ds_name (str): Name of the dataset.
49 | dir (str): Output directory for the saved PDF figure.
50 | show_samples (bool): Whether to plot individual samples along with predictions.
51 | Returns:
52 | None
53 | """
54 | pred = pred_dict['median']
55 | pred = pd.Series(pred, index=test.index)
56 | plt.figure(figsize=(8, 6), dpi=100)
57 | plt.plot(train)
58 | plt.plot(test, label='Truth', color='black')
59 | plt.plot(pred, label=model_name, color='purple')
60 | # Shade 90% confidence interval
61 | samples = pred_dict['samples']
62 | lower = np.quantile(samples, 0.05, axis=0)
63 | upper = np.quantile(samples, 0.95, axis=0)
64 | plt.fill_between(pred.index, lower, upper, alpha=0.3, color='purple')
65 | if show_samples:
66 | samples = pred_dict['samples']
67 | # Convert DataFrame to numpy array
68 | samples = samples.values if isinstance(samples, pd.DataFrame) else samples
69 | for i in range(min(10, samples.shape[0])):
70 | plt.plot(pred.index, samples[i], color='purple', alpha=0.3, linewidth=1)
71 | plt.legend(loc='upper left')
72 | if 'NLL/D' in pred_dict:
73 | nll = pred_dict['NLL/D']
74 | if nll is not None:
75 | plt.text(0.03, 0.85, f'NLL/D: {nll:.2f}', transform=plt.gca().transAxes,
76 | bbox=dict(facecolor='white', alpha=0.5))
77 | plt.savefig(dir + f'/{ds_name}_{model_name}_prediction.pdf', format='pdf') # save before show(), which clears the figure on most backends
78 | plt.show()
79 | 
80 | 
81 | def plot_preds_w_test(test, pred_dict, model_name, ds_name, dir='prediction_w_gemini/prediction_w_test', show_samples=False):
82 | """
83 | Plot predictions with confidence intervals, without training data. (Contains only the test set.)
84 | 
85 | Parameters:
86 | test (pd.Series): Time series of testing data (ground truth).
87 | pred_dict (dict): Dictionary containing predictions and other metrics.
88 | model_name (str): Name of the predictive model.
89 | ds_name (str): Name of the dataset.
90 | dir (str): Output directory prefix for the saved PDF figure.
91 | show_samples (bool): Whether to plot individual samples along with predictions.
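
    Note: the figure is saved to dir + '{ds_name}{model_name}_prediction_meticulous.pdf',
    so the filename is concatenated directly onto dir; include a trailing separator in dir if one is needed.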
92 | 93 | Returns: 94 | None 95 | """ 96 | pred = pred_dict['median'] 97 | pred = pd.Series(pred, index=test.index) 98 | plt.figure(figsize=(8, 6), dpi=100) 99 | # Omit plotting training data 100 | # plt.plot(train) 101 | plt.plot(test, label='Truth', color='black') 102 | plt.plot(pred, label=model_name, color='purple') 103 | # Shade 90% confidence interval 104 | samples = pred_dict['samples'] 105 | lower = np.quantile(samples, 0.05, axis=0) 106 | upper = np.quantile(samples, 0.95, axis=0) 107 | plt.fill_between(pred.index, lower, upper, alpha=0.3, color='purple') 108 | if show_samples: 109 | samples = pred_dict['samples'] 110 | # Convert DataFrame to numpy array 111 | samples = samples.values if isinstance(samples, pd.DataFrame) else samples 112 | for i in range(min(10, samples.shape[0])): 113 | plt.plot(pred.index, samples[i], color='purple', alpha=0.3, linewidth=1) 114 | plt.legend(loc='upper left') 115 | if 'NLL/D' in pred_dict: 116 | nll = pred_dict['NLL/D'] 117 | if nll is not None: 118 | plt.text(0.03, 0.85, f'NLL/D: {nll:.2f}', transform=plt.gca().transAxes, 119 | bbox=dict(facecolor='white', alpha=0.5)) 120 | plt.savefig(dir + f'{ds_name}{model_name}_prediction_meticulous.pdf', format='pdf') 121 | 122 | 123 | def metrics_used(test, dataset_name, original_pred, num_samples=10): 124 | ''' 125 | This function defines the metrics used for evaluating model performance. 126 | 127 | Args: 128 | dataset_name (str): Name of the dataset. 129 | original_pred (dict): Dictionary containing original predictions and samples. 130 | num_samples (int): Number of samples to consider. 131 | 132 | Returns: 133 | tuple: Mean values of MSE, MAE, MAPE, and R². 134 | ''' 135 | print("dataset_name: ", dataset_name) 136 | 137 | mse_amount = 0.0 138 | mae_amount = 0.0 139 | mape_amount = 0.0 140 | rsquare_amount = 0.0 141 | for i in range(num_samples): 142 | seq_pred = original_pred[dataset_name]['samples'].iloc[i, :] 143 | 144 | mse = mean_squared_error(test, seq_pred) 145 | mae = mean_absolute_error(test, seq_pred) 146 | mape = metrics.mean_absolute_percentage_error(test, seq_pred) * 100 147 | r2 = r2_score(test, seq_pred) 148 | 149 | mse_amount += mse 150 | mae_amount += mae 151 | mape_amount += mape 152 | rsquare_amount += r2 153 | 154 | mse_mean = mse_amount / num_samples 155 | mae_mean = mae_amount / num_samples 156 | mape_mean = mape_amount / num_samples 157 | r2_mean = rsquare_amount / num_samples 158 | 159 | # Print and plot values 160 | print("\n") 161 | print('Calculating metrics for each prediction and taking the mean:') 162 | print(f'MSE: {mse_mean}, MAE: {mae_mean}, MAPE: {mape_mean}, R²: {r2_mean}') 163 | print("\n") 164 | 165 | return mse_mean, mae_mean, mape_mean, r2_mean 166 | 167 | def fig_length(metric, metric_name='R^2', dir='length_impact', dataset_name='WineDataset'): 168 | ''' 169 | This function plots a line chart to visualize the length of the training set. 170 | 171 | Args: 172 | metric (list): List of metric values. 173 | metric_name (str): Name of the metric to be displayed on the y-axis. 174 | dataset_name (str): Name of the dataset. 
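        dir (str): Output path prefix for the saved PNG figure (the filename is appended directly, without a separator).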
175 | 176 | Returns: 177 | None 178 | ''' 179 | font_path = fm.findfont(fm.FontProperties(family='Times New Roman')) 180 | font_prop = fm.FontProperties(fname=font_path) 181 | _font_size = 38 182 | sns.set_style('whitegrid') 183 | sns.set(style="whitegrid", rc={"axes.grid.axis": "y", "axes.grid": True}) 184 | fig = plt.figure(figsize=(12,8)) 185 | ax1 = plt.gca() 186 | x = np.arange(len(metric)) 187 | sns.lineplot(x=x, y=metric, color='#d37981', alpha=1, linewidth=5, marker='o', markerfacecolor='w', markeredgecolor='#d37981', markersize=6) 188 | ax1.spines['right'].set_visible(False) 189 | ax1.spines['left'].set_visible(False) 190 | plt.xlabel("Index", fontproperties=font_prop, fontsize=_font_size+5) 191 | plt.ylabel(metric_name, fontproperties=font_prop, fontsize=_font_size+5) 192 | plt.legend(bbox_to_anchor=(0.5, 1.4), loc='upper center', fontsize=_font_size) 193 | plt.tight_layout() 194 | plt.savefig(dir + f"Length_of_Training_Set_Analysis_0407_{dataset_name}.png") 195 | plt.show() 196 | 197 | 198 | def plot_scatter_2d_list(data): 199 | ''' 200 | Plot 2D scatter plot from a two-dimensional list. 201 | 202 | Args: 203 | data (list): Two-dimensional list of data points. 204 | 205 | Returns: 206 | None 207 | ''' 208 | m = len(data) # Length of the first dimension 209 | n = len(data[0]) # Length of the second dimension 210 | 211 | # Initialize x and y coordinates 212 | x_coords = [] 213 | y_coords = [] 214 | 215 | # Traverse the two-dimensional list to extract x and y coordinates 216 | for i in range(m): 217 | for j in range(n): 218 | x_coords.append(i) # x-axis corresponds to the first dimension m 219 | y_coords.append(data[i][j]) # y-axis corresponds to the values in the two-dimensional list 220 | 221 | # Plot scatter plot 222 | plt.scatter(x_coords, y_coords) 223 | 224 | 225 | def fig_counterfactual(metric, metric_name='R^2', dataset_name='WineDataset', dir='counterfactual_analysis'): 226 | ''' 227 | Plot counterfactual analysis figure. 228 | 229 | Args: 230 | metric (list): Two-dimensional list of metric values. 231 | metric_name (str): Name of the metric. 232 | dataset_name (str): Name of the dataset. 
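        dir (str): Output path prefix for the saved PNG figure (the filename is appended directly, without a separator).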
233 |         dir (str): Output path prefix; expected to end with a path separator.
234 |     Returns:
235 |         None
236 |     '''
237 |     font_path = fm.findfont(fm.FontProperties(family='Times New Roman'))
238 |     font_prop = fm.FontProperties(fname=font_path)
239 |     _font_size = 38
240 |     sns.set_style('whitegrid')
241 |     sns.set(style="whitegrid", rc={"axes.grid.axis": "y", "axes.grid": True})
242 |     fig = plt.figure(figsize=(12, 8))
243 |     ax1 = plt.gca()
244 |     x = np.arange(len(metric))
245 |     metric_mean = [sum(sublist) / len(sublist) for sublist in metric]  # mean over the runs in each sublist
246 | 
247 |     sns.lineplot(x=x, y=metric_mean, label=metric_name, color='#d37981', alpha=1, linewidth=5, marker='o', markerfacecolor='w',
248 |                  markeredgecolor='#d37981', markersize=6)
249 |     lower = np.quantile(metric, 0.05, axis=1)  # 5th percentile of the runs at each index
250 |     upper = np.quantile(metric, 0.95, axis=1)  # 95th percentile of the runs at each index
251 |     plt.fill_between(x, lower, upper, alpha=0.3, color='purple')
252 |     plot_scatter_2d_list(data=metric)
253 | 
254 |     ax1.spines['right'].set_visible(False)
255 |     ax1.spines['left'].set_visible(False)
256 |     plt.xlabel("Index", fontproperties=font_prop, fontsize=_font_size + 5)
257 |     plt.ylabel(metric_name, fontproperties=font_prop, fontsize=_font_size + 5)
258 |     plt.legend(bbox_to_anchor=(0.5, 1.4), loc='upper center', fontsize=_font_size)
259 |     plt.tight_layout()
260 |     plt.savefig(dir + f"Counterfactual_Analysis_0407_{dataset_name}.png")
261 |     plt.show()
262 | 
263 | 
264 | def prediction_gemini(model_predict_fns, train, test, model_hypers, num_samples=10, whether_blanket=False,
265 |                       dataset_name='WineDataset', genai_key=None):
266 |     '''
267 |     Run Gemini model predictions on a single dataset.
268 | 
269 |     Args:
270 |         model_predict_fns (dict): Dictionary of model prediction functions.
271 |         train (pd.Series): Time series of training data.
272 |         test (pd.Series): Time series of testing data.
273 |         model_hypers (dict): Dictionary of model hyperparameters.
274 |         num_samples (int): Number of samples for prediction.
275 |         whether_blanket (bool): Whether to use blanket adjustments.
276 |         dataset_name (str): Name of the dataset.
277 |         genai_key (str): API key for the Google Generative AI (Gemini) client.
278 |     Returns:
279 |         tuple: Tuple containing dictionaries of Gemini predictions (gemini-pro and gemini-1.0-pro).
280 |     '''
281 |     model_names = list(model_predict_fns.keys())
282 |     out_gemini_pro = {}  # gemini-pro
283 |     out_gemini_pro_number = {}  # gemini-1.0-pro
284 |     for model in model_names:
285 |         model_hypers[model].update({'dataset_name': dataset_name})  # Add dataset_name to the hyperparameters
286 |         hypers = list(grid_iter(model_hypers[model]))  # Generate hyperparameter combinations
287 | 
288 |         pred_dict = get_autotuned_predictions_data(train, test, hypers, num_samples, model_predict_fns[model],
289 |                                                    verbose=False, parallel=False, whether_blanket=whether_blanket, genai_key=genai_key)
290 |         # Hyperparameters are autotuned on validation likelihood, as described in the original paper
291 |         # (most useful for models with tunable hyperparameters); this path has not been verified end to end.
292 |         if model == 'gemini-pro':
293 |             out_gemini_pro.update({dataset_name: pred_dict})
294 |         if model == 'gemini-1.0-pro':
295 |             out_gemini_pro_number.update({dataset_name: pred_dict})
296 |     return out_gemini_pro, out_gemini_pro_number
297 | 
298 | def opt_hyper_gemini(model_predict_fns, train, test, model_hypers, num_samples=10, whether_blanket=False,
299 |                      dataset_name='WineDataset', genai_key=None, temp_list=[0.2, 0.4, 0.6, 0.8, 1.0], prec_list=[2,3]):
300 |     # Build a list of hyperparameter dicts, one per temperature value
301 |     gemini_hypers_list_0 = [{'temp': temp_val} for temp_val in temp_list]
302 |     gemini_hypers_list = []
303 | 
304 |     output_metrics = []
305 | 
306 |     for prec in prec_list:
307 |         gemini_hypers_list_tmp = copy.deepcopy(gemini_hypers_list_0)
308 |         for hyper_cfg in gemini_hypers_list_tmp:
309 |             hyper_cfg.update({
310 |                 'alpha': 0.95,
311 |                 'beta': 0.3,
312 |                 'basic': [False],
313 |                 'settings': [SerializerSettings(base=10, prec=prec, signed=True, half_bin_correction=True)],
314 |             })
315 |         gemini_hypers_list.extend(gemini_hypers_list_tmp)
316 | 
317 |     for index, hyper_cfg in enumerate(gemini_hypers_list):
318 |         if index > 0:
319 |             time.sleep(60)  # pause between successive runs to stay within API rate limits
320 |         for m in model_predict_fns: model_hypers[m].update(hyper_cfg)  # apply this combination to every model's hyperparameters before predicting
321 |         out_gemini_pro, out_gemini_pro_number = prediction_gemini(model_predict_fns=model_predict_fns, train=train, test=test, model_hypers=model_hypers, num_samples=num_samples, whether_blanket=whether_blanket,
322 |                                                                   dataset_name=dataset_name, genai_key=genai_key)
323 |         mse_mean, mae_mean, mape_mean, r2_mean = metrics_used(test=test, dataset_name=dataset_name, original_pred=out_gemini_pro_number, num_samples=num_samples)
324 |         hyper_cfg.update({'mse': mse_mean, 'mae': mae_mean, 'mape': mape_mean, 'r2': r2_mean})
325 |         output_metrics.append(hyper_cfg)
326 |     return output_metrics
--------------------------------------------------------------------------------
/6_paraphrase_and_predict.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "id": "ab96cd09472b0592",
7 |    "metadata": {
8 |     "ExecuteTime": {
9 |      "end_time": "2024-02-16T06:14:59.176061200Z",
10 |      "start_time": "2024-02-16T06:14:59.161063600Z"
11 |     },
12 |     "collapsed": false,
13 |     "jupyter": {
14 |      "outputs_hidden": false
15 |     }
16 |    },
17 |    "outputs": [],
18 |    "source": [
19 |     "import os\n",
20 |     "import json\n",
21 |     "from utils_paragraph import paraphrase_nlp, paraphrasing_predict_llm, paraphrase_initial, \\\n",
22 |     "    paraphrase_seq2lan, recover_lan2seq, paraphrasing_predict_llama\n",
23 |     "\n",
24 |     "os.environ['OMP_NUM_THREADS'] = '4'\n",
25 |     "import numpy as np\n",
26 |     "import pandas as pd\n",
27 |     "import matplotlib.pyplot as plt\n",
28 |     "import openai\n",
29 |     "\n",
30 |     "with open('config.json', 'r', 
encoding='utf-8') as f:\n", 31 | " config = json.load(f)\n", 32 | "\n", 33 | "openai.api_key = config['OPENAI_API_KEY']\n", 34 | "openai.api_base = config['OPENAI_API_BASE']\n", 35 | "\n", 36 | "\n", 37 | "from data1.serialize import SerializerSettings\n", 38 | "from sklearn import metrics\n", 39 | "from models.darts import get_arima_predictions_data\n", 40 | "from models.llmtime import get_llmtime_predictions_data\n", 41 | "from data1.small_context import get_datasets, get_memorization_datasets, get_dataset\n", 42 | "from models.validation_likelihood_tuning import get_autotuned_predictions_data\n", 43 | "from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n", 44 | "import warnings\n", 45 | "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n", 46 | "\n", 47 | "# load_ext autoreload\n", 48 | "# autoreload 2\n" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 19, 54 | "id": "5b3a230fb7994d9f", 55 | "metadata": { 56 | "ExecuteTime": { 57 | "end_time": "2024-02-16T06:14:59.633542200Z", 58 | "start_time": "2024-02-16T06:14:59.613546900Z" 59 | }, 60 | "collapsed": false, 61 | "jupyter": { 62 | "outputs_hidden": false 63 | } 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "def plot_preds(train, test, pred_dict, model_name, ds_name, show_samples=False):\n", 68 | " pred = pred_dict['median']\n", 69 | " pred = pd.Series(pred, index=test.index)\n", 70 | " plt.figure(figsize=(8, 6), dpi=100)\n", 71 | " plt.plot(train)\n", 72 | " plt.plot(test, label='Truth', color='black')\n", 73 | " plt.plot(pred, label=model_name, color='purple')\n", 74 | " # shade 90% confidence interval\n", 75 | " samples = pred_dict['samples']\n", 76 | " lower = np.quantile(samples, 0.05, axis=0)\n", 77 | " upper = np.quantile(samples, 0.95, axis=0)\n", 78 | " plt.fill_between(pred.index, lower, upper, alpha=0.3, color='purple')\n", 79 | " if show_samples:\n", 80 | " samples = pred_dict['samples']\n", 81 | " # convert df to numpy array\n", 82 | " samples = samples.values if isinstance(samples, pd.DataFrame) else samples\n", 83 | " for i in range(min(10, samples.shape[0])):\n", 84 | " plt.plot(pred.index, samples[i], color='purple', alpha=0.3, linewidth=1)\n", 85 | " plt.legend(loc='upper left')\n", 86 | " if 'NLL/D' in pred_dict:\n", 87 | " nll = pred_dict['NLL/D']\n", 88 | " if nll is not None:\n", 89 | " plt.text(0.03, 0.85, f'NLL/D: {nll:.2f}', transform=plt.gca().transAxes,\n", 90 | " bbox=dict(facecolor='white', alpha=0.5))\n", 91 | " plt.savefig(f'{ds_name}{model_name}givenname1.pdf', format='pdf')\n", 92 | "\n", 93 | "\n", 94 | "def plot_preds2(train, test, pred_dict, model_name, ds_name, show_samples=False):\n", 95 | " pred = pred_dict['median']\n", 96 | " pred = pd.Series(pred, index=test.index)\n", 97 | " plt.figure(figsize=(8, 6), dpi=100)\n", 98 | " # plt.plot(train)\n", 99 | " plt.plot(test, label='Truth', color='black')\n", 100 | " plt.plot(pred, label=model_name, color='purple')\n", 101 | " # shade 90% confidence interval\n", 102 | " samples = pred_dict['samples']\n", 103 | " lower = np.quantile(samples, 0.05, axis=0)\n", 104 | " upper = np.quantile(samples, 0.95, axis=0)\n", 105 | " plt.fill_between(pred.index, lower, upper, alpha=0.3, color='purple')\n", 106 | " if show_samples:\n", 107 | " samples = pred_dict['samples']\n", 108 | " # convert df to numpy array\n", 109 | " samples = samples.values if isinstance(samples, pd.DataFrame) else samples\n", 110 | " for i in range(min(10, samples.shape[0])):\n", 111 | " plt.plot(pred.index, samples[i], 
color='purple', alpha=0.3, linewidth=1)\n", 112 | " plt.legend(loc='upper left')\n", 113 | " if 'NLL/D' in pred_dict:\n", 114 | " nll = pred_dict['NLL/D']\n", 115 | " if nll is not None:\n", 116 | " plt.text(0.03, 0.85, f'NLL/D: {nll:.2f}', transform=plt.gca().transAxes,\n", 117 | " bbox=dict(facecolor='white', alpha=0.5))\n", 118 | " plt.savefig(f'{ds_name}{model_name}givenname2.pdf', format='pdf')" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 20, 124 | "id": "167c8b811e892e23", 125 | "metadata": { 126 | "ExecuteTime": { 127 | "end_time": "2024-02-16T06:15:00.162580800Z", 128 | "start_time": "2024-02-16T06:15:00.145581100Z" 129 | }, 130 | "collapsed": false, 131 | "jupyter": { 132 | "outputs_hidden": false 133 | } 134 | }, 135 | "outputs": [ 136 | { 137 | "name": "stdout", 138 | "output_type": "stream", 139 | "text": [ 140 | "Model_names: ['gpt-3.5-turbo-1106']\n" 141 | ] 142 | } 143 | ], 144 | "source": [ 145 | "gpt4_hypers = dict(\n", 146 | " alpha=0.3,\n", 147 | " basic=True,\n", 148 | " temp=1.0,\n", 149 | " top_p=0.8,\n", 150 | " settings=SerializerSettings(base=10, prec=3, signed=True, time_sep=', ', bit_sep='', minus_sign='-')\n", 151 | ")\n", 152 | "\n", 153 | "gpt3_hypers = dict(\n", 154 | " temp=0.7,\n", 155 | " alpha=0.95,\n", 156 | " beta=0.3,\n", 157 | " basic=False,\n", 158 | " settings=SerializerSettings(base=10, prec=3, signed=True, half_bin_correction=True)\n", 159 | ")\n", 160 | "\n", 161 | "promptcast_hypers = dict(\n", 162 | " temp=0.7,\n", 163 | " settings=SerializerSettings(base=10, prec=0, signed=True,\n", 164 | " time_sep=', ',\n", 165 | " bit_sep='',\n", 166 | " plus_sign='',\n", 167 | " minus_sign='-',\n", 168 | " half_bin_correction=False,\n", 169 | " decimal_point=''))\n", 170 | "\n", 171 | "arima_hypers = dict(p=[12, 30], d=[1, 2], q=[0])\n", 172 | "\n", 173 | "model_predict_fns = {\n", 174 | " 'gpt-3.5-turbo-1106': get_llmtime_predictions_data,\n", 175 | " # 'gpt-4-0125-preview': get_llmtime_predictions_data,\n", 176 | " # 'llama2-13b-chat': get_llmtime_predictions_data,\n", 177 | "}\n", 178 | "\n", 179 | "model_names = list(model_predict_fns.keys())\n", 180 | "print(\"Model_names:\", model_names)\n", 181 | "\n", 182 | "# Initial out dict\n", 183 | "\n", 184 | "datasets_list = [\n", 185 | " 'AirPassengersDataset',\n", 186 | " # 'AusBeerDataset',\n", 187 | " # 'GasRateCO2Dataset',\n", 188 | " 'MonthlyMilkDataset',\n", 189 | " 'SunspotsDataset',\n", 190 | " 'WineDataset',\n", 191 | " # 'WoolyDataset',\n", 192 | " # 'HeartRateDataset',\n", 193 | "\n", 194 | " # 'IstanbulTraffic',\n", 195 | " # 'TSMCStock',\n", 196 | " # 'TurkeyPower',\n", 197 | " # 'ETTh1Dataset',\n", 198 | " # 'ETTm2Dataset',\n", 199 | "]" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 22, 205 | "id": "bf53bd12360a84", 206 | "metadata": { 207 | "ExecuteTime": { 208 | "end_time": "2024-02-16T06:16:30.934974800Z", 209 | "start_time": "2024-02-16T06:16:03.557229900Z" 210 | }, 211 | "collapsed": false, 212 | "jupyter": { 213 | "outputs_hidden": false 214 | } 215 | }, 216 | "outputs": [ 217 | { 218 | "name": "stdout", 219 | "output_type": "stream", 220 | "text": [ 221 | "Model name: gpt-3.5-turbo-1106\n", 222 | "dataset_name: AirPassengersDataset\n", 223 | "Round: 1\n", 224 | "Test_lan: 
495.0,496.0,497.0,502.0,506.0,511.0,512.0,507.0,511.0,514.0,516.0,518.0,519.0,520.0,521.0,520.0,521.0,522.0,520.0,515.0,514.0,511.0,507.0,505.0,503.0,501.0,502.0,500.0,498.0,497.0,496.0,495.0,496.0,495.0,497.0,495.0,496.0,497.0,499.0,501.0,503.0,505.0,503.0,501.0,503.0,501.0,502.0,503.0,502.0,501.0,503.0,505.0\n", 225 | "test len: (29,)\n", 226 | "Seq_pred: (27,)\n", 227 | "Not enough sequences for prediction\n", 228 | "\n", 229 | "\n", 230 | "MSE: 0.0, MAE: 0.0, MAPE: 0.0, R²: 0.0\n", 231 | "\n", 232 | "\n", 233 | "dataset_name: MonthlyMilkDataset\n", 234 | "Round: 1\n", 235 | "Test_lan: 706.0,724.0,668.0,629.0,675.0,694.0,673.0,751.0,771.0,835.0,807.0,763.0,719.0,680.0,682.0,648.0,692.0,708.0,665.0,767.0,787.0,851.0,822.0,778.0,734.0\n", 236 | "test len: (34,)\n", 237 | "Seq_pred: (14,)\n", 238 | "Not enough sequences for prediction\n", 239 | "\n", 240 | "\n", 241 | "MSE: 0.0, MAE: 0.0, MAPE: 0.0, R²: 0.0\n", 242 | "\n", 243 | "\n", 244 | "dataset_name: SunspotsDataset\n", 245 | "Round: 1\n", 246 | "Test_lan: 57.0, 70.0, 77.0, 64.4, 50.0, 65.8, 69.3, 66.1, 57.0, 68.4, 68.4, 60.3, 58.3, 47.0, 33.9, 50.0, 21.2, 14.4, 33.0, 85.4, 106.8, 98.2, 95.9, 109.7, 89.3, 65.6, 95.7, 99.7, 104.6, 48.7, 77.3, 78.3, 70.0, 73.3, 74.3, 75.3, 96.0, 46.0, 95.3, 67.0, 56.0, 87.0, 48.3, 65.2, 74.1, 71.3, 62.4, 61.6, 68.5, 73.0, 70.0, 55.3, 56.0, 77.0, 53.0, 49.0, 31.8, 63.7, 33.2, 29.4, 14.0, 8.6, 21.3, 37.1, 47.3, 51.2, 53.2, 73.6, 37.6, 73.6, 64.7, 63.1, 77.5, 60.7, 68.4, 63.0, 27.5, 51.4, 45.2, 30.1, 48.6, 43.3, 44.8, 52.1, 24.6, 38.4, 69.2, 93.1, 77.1, 70.0, 156.0, 112.0, 92.3, 132.5, 89.9, 88.1, 127.8, 83.5, 83.5, 71.4, 73.1, 111.1, 68.5, 78.7, 67.0, 54.6, 76.4, 63.5, 65.6, 63.2, 37.7, 36.1, 73.1, 43.6, 25.6, 59.1, 33.7, 33.4, 46.6, 92.2, 50.8, 54.6, 65.4, 45.3, 52.1, 58.0, 54.0, 48.8, 45.8, 32.0, 30.2, 34.9, 26.3, 20.5, 18.8, 18.8, 18.8, 18.8, 27.2, 18.2, 22.5, 5.9, 5.3, 16.0, 24.8, 40.8, 53.5, 49.5, 57.3, 49.3, 39.3, 70.1, 68.3, 73.4, 88.8, 65.5, 85.1, 68.2, 49.3, 46.3, 74.6, 57.9, 53.4, 41.7, 39.8, 53.8, 47.2, 44.3, 55.9, 25.6, 31.6, 41.3, 30.6, 34.3, 58.5, 29.3, 42.9, 63.3, 51.5, 75.1, 83.7, 75.5, 95.5, 61.4, 76.2, 92.6, 93.9, 81.3, 78.0, 96.0, 59.7, 68.0, 69.6, 77.5, 57.5, 67.3, 67.1, 62.7, 56.2, 56.6, 38.3, 46.0, 34.5, 31.9, 23.6, 22.8, 17.6, 16.4, 22.3, 9.1, 7.9, 5.6, 2.8, 7.8, 16.0, 29.8, 44.8, 47.8, 58.1, 50.1, 47.9, 86.2, 64.1, 76.0, 115.3, 78.8, 97.3, 116.9, 107.7, 87.3, 146.7, 95.3, 89.3, 123.6, 109.3, 72.2, 112.2, 68.2, 101.0, 111.0, 75.6, 58.5, 78.3, 45.1, 35.2, 22.4, 25.4, 25.2, 25.2, 27.2, 18.3, 34.4, 14.9, 11.3, 33.3, 8.1, 43.5, 63.2, 48.2, 66.5, 72.7, 64.0, 80.5, 57.4, 65.1, 86.2, 70.1, 68.1, 85.5, 101.9, 68.7, 78.6, 84.6, 69.2, 57.6, 66.7, 60.1, 43.1, 30.3, 26.0, 25.0, 34.0, 10.0, 35.0, 65.0\n", 247 | "test len: (141,)\n", 248 | "Seq_pred: (144,)\n", 249 | "\n", 250 | "\n", 251 | "MSE: 5607.753333333333, MAE: 61.74751773049644, MAPE: 503.4133172146274, R²: -0.7545148138640969\n", 252 | "\n", 253 | "\n", 254 | "dataset_name: WineDataset\n", 255 | "Round: 1\n", 256 | "Test_lan: 31358.0, 31358.0 increasing to 34358.0, 34358.0 decreasing to 16695.0, 16695.0 increasing to 20624.0, 20624.0 decreasing to 19109.0, 19109.0 increasing to 22740.0, 22740.0 decreasing to 19010.0, 19010.0 increasing to 19270.0, 19270.0 increasing to 20933.0, 20933.0 increasing to 24257.0, 24257.0 increasing to 29161.0, 29161.0 decreasing to 28961.0, 28961.0 increasing to 36161.0, 36161.0 decreasing to 18615.0, 18615.0 increasing to 22583.0, 22583.0 increasing to 25383.0, 25383.0 decreasing to 22183.0, 22183.0 
decreasing to 20338.0, 20338.0 increasing to 29161.0, 29161.0 decreasing to 26157.0\n",
257 |       "test len: (36,)\n",
258 |       "Seq_pred: (21,)\n",
259 |       "Not enough sequences for prediction\n",
260 |       "\n",
261 |       "\n",
262 |       "MSE: 0.0, MAE: 0.0, MAPE: 0.0, R²: 0.0\n",
263 |       "\n",
264 |       "\n",
265 |       "-------------------------New Model\n"
266 |      ]
267 |     }
268 |    ],
269 |    "source": [
270 |     "out = {}\n",
271 |     "datasets = get_datasets()\n",
272 |     "num_samples = 1\n",
273 |     "\n",
274 |     "for model in model_names: # GPT-4 takes about a minute to run\n",
275 |     "    print(\"Model name: \", model)\n",
276 |     "    steps = 500  # prediction steps\n",
277 |     "    for dataset_name in datasets_list:\n",
278 |     "        mse_amount = 0.0\n",
279 |     "        mae_amount = 0.0\n",
280 |     "        mape_amount = 0.0\n",
281 |     "        rsquare_amount = 0.0\n",
282 |     "        print(\"dataset_name: \", dataset_name)\n",
283 |     "        for i in range(num_samples):\n",
284 |     "            print(\"Round: \", i+1)\n",
285 |     "            desp = paraphrase_initial(dataset_name)\n",
286 |     "            data = datasets[dataset_name]\n",
287 |     "            train, test = data\n",
288 |     "            Train_lan = paraphrase_seq2lan(train, desp)\n",
289 |     "            Test_lan = paraphrase_seq2lan(test, desp)\n",
290 |     "            seq_test = recover_lan2seq(Test_lan)\n",
291 |     "            seq_pred = paraphrasing_predict_llm(desp, Train_lan, steps, model)\n",
292 |     "\n",
293 |     "            print(\"test len:\", test.shape)\n",
294 |     "            print(\"Seq_pred:\", seq_pred.shape)\n",
295 |     "            if len(seq_pred) >= len(test):\n",
296 |     "                seq_pred = seq_pred[:len(test)]\n",
297 |     "            else:\n",
298 |     "                print(\"Not enough sequences for prediction\")\n",
299 |     "                break\n",
300 |     "            mse = mean_squared_error(test, seq_pred)\n",
301 |     "            mae = mean_absolute_error(test, seq_pred)\n",
302 |     "            mape = metrics.mean_absolute_percentage_error(test, seq_pred)*100\n",
303 |     "            r2 = r2_score(test, seq_pred)\n",
304 |     "\n",
305 |     "            mse_amount += mse\n",
306 |     "            mae_amount += mae\n",
307 |     "            mape_amount += mape\n",
308 |     "            rsquare_amount += r2\n",
309 |     "\n",
310 |     "        mse_mean = mse_amount/num_samples\n",
311 |     "        mae_mean = mae_amount/num_samples\n",
312 |     "        mape_mean = mape_amount/num_samples\n",
313 |     "        r2_mean = rsquare_amount/num_samples\n",
314 |     "        \n",
315 |     "        # print the mean values\n",
316 |     "        print(\"\\n\")\n",
317 |     "        print(f'MSE: {mse_mean}, MAE: {mae_mean}, MAPE: {mape_mean}, R²: {r2_mean}')\n",
318 |     "        print(\"\\n\")\n",
319 |     "    print('-------------------------New Model')"
320 |    ]
321 |   },
322 |   {
323 |    "cell_type": "code",
324 |    "execution_count": 25,
325 |    "id": "f05ecc5c3f2ed3eb",
326 |    "metadata": {
327 |     "ExecuteTime": {
328 |      "end_time": "2024-02-16T06:19:17.128184200Z",
329 |      "start_time": "2024-02-16T06:19:17.094185700Z"
330 |     },
331 |     "collapsed": false,
332 |     "jupyter": {
333 |      "outputs_hidden": false
334 |     }
335 |    },
336 |    "outputs": [
337 |     {
338 |      "data": {
339 |       "text/plain": [
340 |        "Month\n",
341 |        "1937-01-01 132.5\n",
342 |        "1937-05-01 116.7\n",
343 |        "1937-09-01 100.7\n",
344 |        "1938-01-01 98.4\n",
345 |        "1938-05-01 127.4\n",
346 |        " ... \n",
\n", 347 | "1982-05-01 82.2\n", 348 | "1982-09-01 118.8\n", 349 | "1983-01-01 84.3\n", 350 | "1983-05-01 99.2\n", 351 | "1983-09-01 50.3\n", 352 | "Freq: 4MS, Name: Sunspots, Length: 141, dtype: float64" 353 | ] 354 | }, 355 | "execution_count": 25, 356 | "metadata": {}, 357 | "output_type": "execute_result" 358 | } 359 | ], 360 | "source": [ 361 | "data = datasets[\"SunspotsDataset\"]\n", 362 | "train, test = data\n", 363 | "test" 364 | ] 365 | } 366 | ], 367 | "metadata": { 368 | "kernelspec": { 369 | "display_name": "Python 3 (ipykernel)", 370 | "language": "python", 371 | "name": "python3" 372 | }, 373 | "language_info": { 374 | "codemirror_mode": { 375 | "name": "ipython", 376 | "version": 3 377 | }, 378 | "file_extension": ".py", 379 | "mimetype": "text/x-python", 380 | "name": "python", 381 | "nbconvert_exporter": "python", 382 | "pygments_lexer": "ipython3", 383 | "version": "3.11.8" 384 | } 385 | }, 386 | "nbformat": 4, 387 | "nbformat_minor": 5 388 | } 389 | -------------------------------------------------------------------------------- /data1/autoformer_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import os 5 | from typing import List 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | from sklearn.preprocessing import StandardScaler 9 | from pandas.tseries import offsets 10 | from pandas.tseries.frequencies import to_offset 11 | 12 | import warnings 13 | 14 | warnings.filterwarnings('ignore') 15 | 16 | class TimeFeature: 17 | def __init__(self): 18 | pass 19 | 20 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 21 | pass 22 | 23 | def __repr__(self): 24 | return self.__class__.__name__ + "()" 25 | 26 | 27 | class SecondOfMinute(TimeFeature): 28 | """Minute of hour encoded as value between [-0.5, 0.5]""" 29 | 30 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 31 | return index.second / 59.0 - 0.5 32 | 33 | 34 | class MinuteOfHour(TimeFeature): 35 | """Minute of hour encoded as value between [-0.5, 0.5]""" 36 | 37 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 38 | return index.minute / 59.0 - 0.5 39 | 40 | 41 | class HourOfDay(TimeFeature): 42 | """Hour of day encoded as value between [-0.5, 0.5]""" 43 | 44 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 45 | return index.hour / 23.0 - 0.5 46 | 47 | 48 | class DayOfWeek(TimeFeature): 49 | """Hour of day encoded as value between [-0.5, 0.5]""" 50 | 51 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 52 | return index.dayofweek / 6.0 - 0.5 53 | 54 | 55 | class DayOfMonth(TimeFeature): 56 | """Day of month encoded as value between [-0.5, 0.5]""" 57 | 58 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 59 | return (index.day - 1) / 30.0 - 0.5 60 | 61 | 62 | class DayOfYear(TimeFeature): 63 | """Day of year encoded as value between [-0.5, 0.5]""" 64 | 65 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 66 | return (index.dayofyear - 1) / 365.0 - 0.5 67 | 68 | 69 | class MonthOfYear(TimeFeature): 70 | """Month of year encoded as value between [-0.5, 0.5]""" 71 | 72 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 73 | return (index.month - 1) / 11.0 - 0.5 74 | 75 | 76 | class WeekOfYear(TimeFeature): 77 | """Week of year encoded as value between [-0.5, 0.5]""" 78 | 79 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 80 | return (index.isocalendar().week - 1) / 52.0 - 0.5 81 | 82 | 83 | def 
time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]: 84 | """ 85 | Returns a list of time features that will be appropriate for the given frequency string. 86 | Parameters 87 | ---------- 88 | freq_str 89 | Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc. 90 | """ 91 | 92 | features_by_offsets = { 93 | offsets.YearEnd: [], 94 | offsets.QuarterEnd: [MonthOfYear], 95 | offsets.MonthEnd: [MonthOfYear], 96 | offsets.Week: [DayOfMonth, WeekOfYear], 97 | offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear], 98 | offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear], 99 | offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear], 100 | offsets.Minute: [ 101 | MinuteOfHour, 102 | HourOfDay, 103 | DayOfWeek, 104 | DayOfMonth, 105 | DayOfYear, 106 | ], 107 | offsets.Second: [ 108 | SecondOfMinute, 109 | MinuteOfHour, 110 | HourOfDay, 111 | DayOfWeek, 112 | DayOfMonth, 113 | DayOfYear, 114 | ], 115 | } 116 | 117 | offset = to_offset(freq_str) 118 | 119 | for offset_type, feature_classes in features_by_offsets.items(): 120 | if isinstance(offset, offset_type): 121 | return [cls() for cls in feature_classes] 122 | 123 | supported_freq_msg = f""" 124 | Unsupported frequency {freq_str} 125 | The following frequencies are supported: 126 | Y - yearly 127 | alias: A 128 | M - monthly 129 | W - weekly 130 | D - daily 131 | B - business days 132 | H - hourly 133 | T - minutely 134 | alias: min 135 | S - secondly 136 | """ 137 | raise RuntimeError(supported_freq_msg) 138 | 139 | 140 | def time_features(dates, freq='h'): 141 | return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)]) 142 | 143 | 144 | 145 | class Dataset_ETT_hour(Dataset): 146 | def __init__(self, root_path, flag='train', size=None, 147 | features='S', data_path='ETTh1.csv', 148 | target='OT', scale=True, timeenc=0, freq='h'): 149 | # size [seq_len, label_len, pred_len] 150 | # info 151 | if size == None: 152 | self.seq_len = 24 * 4 * 4 153 | self.label_len = 24 * 4 154 | self.pred_len = 24 * 4 155 | else: 156 | self.seq_len = size[0] 157 | self.label_len = size[1] 158 | self.pred_len = size[2] 159 | # init 160 | assert flag in ['train', 'test', 'val'] 161 | type_map = {'train': 0, 'val': 1, 'test': 2} 162 | self.set_type = type_map[flag] 163 | 164 | self.features = features 165 | self.target = target 166 | self.scale = scale 167 | self.timeenc = timeenc 168 | self.freq = freq 169 | 170 | self.root_path = root_path 171 | self.data_path = data_path 172 | self.__read_data__() 173 | 174 | def __read_data__(self): 175 | self.scaler = StandardScaler() 176 | df_raw = pd.read_csv(os.path.join(self.root_path, 177 | self.data_path)) 178 | 179 | num_train = int(len(df_raw) * 0.85) 180 | num_test = self.pred_len # int(len(df_raw) * 0.2) 181 | num_vali = len(df_raw) - num_train - num_test 182 | border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len] 183 | border2s = [num_train, num_train + num_vali, len(df_raw)] 184 | 185 | # border1s = [0, 12 * 30 * 24 - self.seq_len, 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len] 186 | # border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24] 187 | border1 = border1s[self.set_type] 188 | border2 = border2s[self.set_type] 189 | 190 | if self.features == 'M' or self.features == 'MS': 191 | cols_data = df_raw.columns[1:] 192 | df_data = df_raw[cols_data] 193 | elif self.features == 'S': 194 | df_data = df_raw[[self.target]] 195 | 196 | if self.scale: 197 | train_data = df_data[border1s[0]:border2s[0]] 
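            # NOTE: the scaler below is fit on the training split only (border1s[0]:border2s[0]),
            # so the validation and test windows are standardized with training statistics and
            # no information from the test period leaks into the scaling.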
198 |             self.scaler.fit(train_data.values)
199 |             data = self.scaler.transform(df_data.values)
200 |         else:
201 |             data = df_data.values
202 | 
203 |         df_stamp = df_raw[['date']][border1:border2]
204 |         df_stamp['date'] = pd.to_datetime(df_stamp.date)
205 |         if self.timeenc == 0:
206 |             df_stamp['month'] = df_stamp.date.apply(lambda row: row.month)
207 |             df_stamp['day'] = df_stamp.date.apply(lambda row: row.day)
208 |             df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday())
209 |             df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour)
210 |             data_stamp = df_stamp.drop(['date'], axis=1).values
211 |         elif self.timeenc == 1:
212 |             data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
213 |             data_stamp = data_stamp.transpose(1, 0)
214 | 
215 |         self.data_x = data[border1:border2]
216 |         self.data_y = data[border1:border2]
217 |         self.data_stamp = data_stamp
218 | 
219 |     def __getitem__(self, index):
220 |         s_begin = index
221 |         s_end = s_begin + self.seq_len
222 |         r_begin = s_end - self.label_len
223 |         r_end = r_begin + self.label_len + self.pred_len
224 | 
225 |         seq_x = self.data_x[s_begin:s_end]
226 |         seq_y = self.data_y[r_begin:r_end]
227 |         seq_x_mark = self.data_stamp[s_begin:s_end]
228 |         seq_y_mark = self.data_stamp[r_begin:r_end]
229 | 
230 |         return seq_x, seq_y, seq_x_mark, seq_y_mark
231 | 
232 |     def __len__(self):
233 |         return len(self.data_x) - self.seq_len - self.pred_len + 1
234 | 
235 |     def inverse_transform(self, data):
236 |         return self.scaler.inverse_transform(data)
237 | 
238 | 
239 | class Dataset_ETT_minute(Dataset):
240 |     def __init__(self, root_path, flag='train', size=None,
241 |                  features='S', data_path='ETTm1.csv',
242 |                  target='OT', scale=True, timeenc=0, freq='t'):
243 |         # size [seq_len, label_len, pred_len]
244 |         # info
245 |         if size is None:
246 |             self.seq_len = 24 * 4 * 4
247 |             self.label_len = 24 * 4
248 |             self.pred_len = 24 * 4
249 |         else:
250 |             self.seq_len = size[0]
251 |             self.label_len = size[1]
252 |             self.pred_len = size[2]
253 |         # init
254 |         assert flag in ['train', 'test', 'val']
255 |         type_map = {'train': 0, 'val': 1, 'test': 2}
256 |         self.set_type = type_map[flag]
257 | 
258 |         self.features = features
259 |         self.target = target
260 |         self.scale = scale
261 |         self.timeenc = timeenc
262 |         self.freq = freq
263 | 
264 |         self.root_path = root_path
265 |         self.data_path = data_path
266 |         self.__read_data__()
267 | 
268 |     def __read_data__(self):
269 |         self.scaler = StandardScaler()
270 |         df_raw = pd.read_csv(os.path.join(self.root_path,
271 |                                           self.data_path))
272 | 
273 |         num_train = int(len(df_raw) * 0.85)
274 |         num_test = self.pred_len  # int(len(df_raw) * 0.2)
275 |         num_vali = len(df_raw) - num_train - num_test
276 |         border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
277 |         border2s = [num_train, num_train + num_vali, len(df_raw)]
278 | 
279 |         # border1s = [0, 12 * 30 * 24 * 4 - self.seq_len, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len]
280 |         # border2s = [12 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4]
281 |         border1 = border1s[self.set_type]
282 |         border2 = border2s[self.set_type]
283 | 
284 |         if self.features == 'M' or self.features == 'MS':
285 |             cols_data = df_raw.columns[1:]
286 |             df_data = df_raw[cols_data]
287 |         elif self.features == 'S':
288 |             df_data = df_raw[[self.target]]
289 | 
290 |         if self.scale:
291 |             train_data = df_data[border1s[0]:border2s[0]]
292 |             self.scaler.fit(train_data.values)
293 |             data = self.scaler.transform(df_data.values)
294 |         else:
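            # scale=False: keep the raw, unstandardized values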
295 |             data = df_data.values
296 | 
297 |         df_stamp = df_raw[['date']][border1:border2]
298 |         df_stamp['date'] = pd.to_datetime(df_stamp.date)
299 |         if self.timeenc == 0:
300 |             df_stamp['month'] = df_stamp.date.apply(lambda row: row.month)
301 |             df_stamp['day'] = df_stamp.date.apply(lambda row: row.day)
302 |             df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday())
303 |             df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour)
304 |             df_stamp['minute'] = df_stamp.date.apply(lambda row: row.minute)
305 |             df_stamp['minute'] = df_stamp.minute.map(lambda x: x // 15)
306 |             data_stamp = df_stamp.drop(['date'], axis=1).values
307 |         elif self.timeenc == 1:
308 |             data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
309 |             data_stamp = data_stamp.transpose(1, 0)
310 | 
311 |         self.data_x = data[border1:border2]
312 |         self.data_y = data[border1:border2]
313 |         self.data_stamp = data_stamp
314 | 
315 |     def __getitem__(self, index):
316 |         s_begin = index
317 |         s_end = s_begin + self.seq_len
318 |         r_begin = s_end - self.label_len
319 |         r_end = r_begin + self.label_len + self.pred_len
320 | 
321 |         seq_x = self.data_x[s_begin:s_end]
322 |         seq_y = self.data_y[r_begin:r_end]
323 |         seq_x_mark = self.data_stamp[s_begin:s_end]
324 |         seq_y_mark = self.data_stamp[r_begin:r_end]
325 | 
326 |         return seq_x, seq_y, seq_x_mark, seq_y_mark
327 | 
328 |     def __len__(self):
329 |         return len(self.data_x) - self.seq_len - self.pred_len + 1
330 | 
331 |     def inverse_transform(self, data):
332 |         return self.scaler.inverse_transform(data)
333 | 
334 | 
335 | class Dataset_Custom(Dataset):
336 |     def __init__(self, root_path, flag='train', size=None,
337 |                  features='S', data_path='ETTh1.csv',
338 |                  target='OT', scale=True, timeenc=0, freq='h'):
339 |         # size [seq_len, label_len, pred_len]
340 |         # info
341 |         if size is None:
342 |             self.seq_len = 24 * 4 * 4
343 |             self.label_len = 24 * 4
344 |             self.pred_len = 24 * 4
345 |         else:
346 |             self.seq_len = size[0]
347 |             self.label_len = size[1]
348 |             self.pred_len = size[2]
349 |         # init
350 |         assert flag in ['train', 'test', 'val']
351 |         type_map = {'train': 0, 'val': 1, 'test': 2}
352 |         self.set_type = type_map[flag]
353 | 
354 |         self.features = features
355 |         self.target = target
356 |         self.scale = scale
357 |         self.timeenc = timeenc
358 |         self.freq = freq
359 | 
360 |         self.root_path = root_path
361 |         self.data_path = data_path
362 |         self.__read_data__()
363 | 
364 |     def __read_data__(self):
365 |         self.scaler = StandardScaler()
366 |         df_raw = pd.read_csv(os.path.join(self.root_path,
367 |                                           self.data_path))
368 | 
369 |         '''
370 |         df_raw.columns: ['date', ...(other features), target feature]
371 |         '''
372 |         cols = list(df_raw.columns)
373 |         cols.remove(self.target)
374 |         cols.remove('date')
375 |         df_raw = df_raw[['date'] + cols + [self.target]]
376 |         # print(cols)
377 |         num_train = int(len(df_raw) * 0.85)
378 |         num_test = self.pred_len  # int(len(df_raw) * 0.2)
379 |         num_vali = len(df_raw) - num_train - num_test
380 |         border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
381 |         border2s = [num_train, num_train + num_vali, len(df_raw)]
382 |         border1 = border1s[self.set_type]
383 |         border2 = border2s[self.set_type]
384 | 
385 |         if self.features == 'M' or self.features == 'MS':
386 |             cols_data = df_raw.columns[1:]
387 |             df_data = df_raw[cols_data]
388 |         elif self.features == 'S':
389 |             df_data = df_raw[[self.target]]
390 | 
391 |         if self.scale:
392 |             train_data = df_data[border1s[0]:border2s[0]]
393 |             self.scaler.fit(train_data.values)
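            # As in the ETT dataset classes above, the scaler statistics come from the
            # training rows only and are then applied to the whole frame.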
394 |             data = self.scaler.transform(df_data.values)
395 |         else:
396 |             data = df_data.values
397 | 
398 |         df_stamp = df_raw[['date']][border1:border2]
399 |         df_stamp['date'] = pd.to_datetime(df_stamp.date)
400 |         if self.timeenc == 0:
401 |             df_stamp['month'] = df_stamp.date.apply(lambda row: row.month)
402 |             df_stamp['day'] = df_stamp.date.apply(lambda row: row.day)
403 |             df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday())
404 |             df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour)
405 |             data_stamp = df_stamp.drop(['date'], axis=1).values
406 |         elif self.timeenc == 1:
407 |             data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
408 |             data_stamp = data_stamp.transpose(1, 0)
409 | 
410 |         self.data_x = data[border1:border2]
411 |         self.data_y = data[border1:border2]
412 |         self.data_stamp = data_stamp
413 | 
414 |     def __getitem__(self, index):
415 |         s_begin = index
416 |         s_end = s_begin + self.seq_len
417 |         r_begin = s_end - self.label_len
418 |         r_end = r_begin + self.label_len + self.pred_len
419 | 
420 |         seq_x = self.data_x[s_begin:s_end]
421 |         seq_y = self.data_y[r_begin:r_end]
422 |         seq_x_mark = self.data_stamp[s_begin:s_end]
423 |         seq_y_mark = self.data_stamp[r_begin:r_end]
424 | 
425 |         return seq_x, seq_y, seq_x_mark, seq_y_mark
426 | 
427 |     def __len__(self):
428 |         return len(self.data_x) - self.seq_len - self.pred_len + 1
429 | 
430 |     def inverse_transform(self, data):
431 |         return self.scaler.inverse_transform(data)
432 | 
433 | 
434 | class Dataset_Pred(Dataset):
435 |     def __init__(self, root_path, flag='pred', size=None,
436 |                  features='S', data_path='ETTh1.csv',
437 |                  target='OT', scale=True, inverse=False, timeenc=0, freq='15min', cols=None):
438 |         # size [seq_len, label_len, pred_len]
439 |         # info
440 |         if size is None:
441 |             self.seq_len = 24 * 4 * 4
442 |             self.label_len = 24 * 4
443 |             self.pred_len = 24 * 4
444 |         else:
445 |             self.seq_len = size[0]
446 |             self.label_len = size[1]
447 |             self.pred_len = size[2]
448 |         # init
449 |         assert flag in ['pred']
450 | 
451 |         self.features = features
452 |         self.target = target
453 |         self.scale = scale
454 |         self.inverse = inverse
455 |         self.timeenc = timeenc
456 |         self.freq = freq
457 |         self.cols = cols
458 |         self.root_path = root_path
459 |         self.data_path = data_path
460 |         self.__read_data__()
461 | 
462 |     def __read_data__(self):
463 |         self.scaler = StandardScaler()
464 |         df_raw = pd.read_csv(os.path.join(self.root_path,
465 |                                           self.data_path))
466 |         '''
467 |         df_raw.columns: ['date', ...(other features), target feature]
468 |         '''
469 |         if self.cols:
470 |             cols = self.cols.copy()
471 |             cols.remove(self.target)
472 |         else:
473 |             cols = list(df_raw.columns)
474 |             cols.remove(self.target)
475 |             cols.remove('date')
476 |         df_raw = df_raw[['date'] + cols + [self.target]]
477 |         border1 = len(df_raw) - self.seq_len
478 |         border2 = len(df_raw)
479 | 
480 |         if self.features == 'M' or self.features == 'MS':
481 |             cols_data = df_raw.columns[1:]
482 |             df_data = df_raw[cols_data]
483 |         elif self.features == 'S':
484 |             df_data = df_raw[[self.target]]
485 | 
486 |         if self.scale:
487 |             self.scaler.fit(df_data.values)  # in pred mode all available history is training data
488 |             data = self.scaler.transform(df_data.values)
489 |         else:
490 |             data = df_data.values
491 | 
492 |         tmp_stamp = df_raw[['date']][border1:border2]
493 |         tmp_stamp['date'] = pd.to_datetime(tmp_stamp.date)
494 |         pred_dates = pd.date_range(tmp_stamp.date.values[-1], periods=self.pred_len + 1, freq=self.freq)
495 | 
496 |         df_stamp = pd.DataFrame(columns=['date'])
497 |         df_stamp.date = list(tmp_stamp.date.values) + list(pred_dates[1:])
498 |         if self.timeenc == 0:
499 |             df_stamp['month'] = df_stamp.date.apply(lambda row: row.month)
500 |             df_stamp['day'] = df_stamp.date.apply(lambda row: row.day)
501 |             df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday())
502 |             df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour)
503 |             df_stamp['minute'] = df_stamp.date.apply(lambda row: row.minute)
504 |             df_stamp['minute'] = df_stamp.minute.map(lambda x: x // 15)
505 |             data_stamp = df_stamp.drop(['date'], axis=1).values
506 |         elif self.timeenc == 1:
507 |             data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
508 |             data_stamp = data_stamp.transpose(1, 0)
509 | 
510 |         self.data_x = data[border1:border2]
511 |         if self.inverse:
512 |             self.data_y = df_data.values[border1:border2]
513 |         else:
514 |             self.data_y = data[border1:border2]
515 |         self.data_stamp = data_stamp
516 | 
517 |     def __getitem__(self, index):
518 |         s_begin = index
519 |         s_end = s_begin + self.seq_len
520 |         r_begin = s_end - self.label_len
521 |         r_end = r_begin + self.label_len + self.pred_len
522 | 
523 |         seq_x = self.data_x[s_begin:s_end]
524 |         if self.inverse:
525 |             seq_y = self.data_x[r_begin:r_begin + self.label_len]
526 |         else:
527 |             seq_y = self.data_y[r_begin:r_begin + self.label_len]
528 |         seq_x_mark = self.data_stamp[s_begin:s_end]
529 |         seq_y_mark = self.data_stamp[r_begin:r_end]
530 | 
531 |         return seq_x, seq_y, seq_x_mark, seq_y_mark
532 | 
533 |     def __len__(self):
534 |         return len(self.data_x) - self.seq_len + 1
535 | 
536 |     def inverse_transform(self, data):
537 |         return self.scaler.inverse_transform(data)
538 | 
539 | 
540 | 
541 | data_dict = {
542 |     'ETTh1': Dataset_ETT_hour,
543 |     'ETTh2': Dataset_ETT_hour,
544 |     'ETTm1': Dataset_ETT_minute,
545 |     'ETTm2': Dataset_ETT_minute,
546 |     'custom': Dataset_Custom,
547 | }
548 | 
549 | 
550 | def data_provider(args, flag):
551 |     Data = data_dict[args.data]
552 |     timeenc = 0 if args.embed != 'timeF' else 1
553 | 
554 |     if flag == 'test':
555 |         shuffle_flag = False
556 |         drop_last = False
557 |         batch_size = args.batch_size
558 |         freq = args.freq
559 |     elif flag == 'pred':
560 |         shuffle_flag = False
561 |         drop_last = False
562 |         batch_size = 1
563 |         freq = args.freq
564 |         Data = Dataset_Pred
565 |     else:
566 |         shuffle_flag = True
567 |         drop_last = True
568 |         batch_size = args.batch_size
569 |         freq = args.freq
570 | 
571 |     data_set = Data(
572 |         root_path=args.root_path,
573 |         data_path=args.data_path,
574 |         flag=flag,
575 |         size=[args.seq_len, args.label_len, args.pred_len],
576 |         features=args.features,
577 |         target=args.target,
578 |         timeenc=timeenc,
579 |         freq=freq
580 |     )
581 |     print(flag, len(data_set))
582 |     data_loader = DataLoader(
583 |         data_set,
584 |         batch_size=batch_size,
585 |         shuffle=shuffle_flag,
586 |         num_workers=args.num_workers,
587 |         drop_last=drop_last)
588 |     return data_set, data_loader
589 | 
--------------------------------------------------------------------------------
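For reference, below is a minimal sketch of how data_provider can be driven end to end. It assumes an ETTh1-style CSV is available under root_path and that this module is importable as data1.autoformer_dataset; all of the Namespace values are illustrative placeholders, not settings taken from this repository's experiment scripts.

from argparse import Namespace

from data1.autoformer_dataset import data_provider  # assumed import path for this file

# Illustrative configuration; every value below is an assumption for demonstration.
args = Namespace(
    data='ETTh1',              # key into data_dict -> Dataset_ETT_hour
    embed='timeF',             # 'timeF' selects timeenc=1 (continuous time features)
    root_path='./data1',       # assumed location of the CSV file
    data_path='ETTh1.csv',
    seq_len=384, label_len=96, pred_len=96,   # mirrors the class defaults (24*4*4, 24*4, 24*4)
    features='S', target='OT', freq='h',
    batch_size=32, num_workers=0,
)

train_set, train_loader = data_provider(args, flag='train')
for seq_x, seq_y, seq_x_mark, seq_y_mark in train_loader:
    # seq_x: (batch, seq_len, n_vars); seq_y: (batch, label_len + pred_len, n_vars);
    # seq_x_mark / seq_y_mark hold the encoded time features for the same windows.
    break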