├── janestreet ├── __init__.py ├── models │ ├── __init__.py │ └── nn.py ├── setup_env.py ├── config.py ├── utils.py ├── metrics.py ├── kaggle.py ├── tracker.py ├── transformers.py ├── data_processor.py └── pipeline.py ├── scripts └── python │ ├── update_kaggle.py │ ├── run_ensemble.py │ ├── run_test_gap.py │ ├── run_cv.py │ ├── run_full_3.py │ ├── run_full_2.py │ └── monitor_kaggle.py ├── setup.py ├── pyproject.toml ├── README.md ├── .gitignore └── solution.md /janestreet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /janestreet/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/python/update_kaggle.py: -------------------------------------------------------------------------------- 1 | """Update Kaggle datasets: models and code.""" 2 | 3 | from janestreet.kaggle import update_dataset, upload_code 4 | from janestreet.config import PATH_MODELS, PATH_CODE 5 | from janestreet.setup_env import setup_environment 6 | 7 | setup_environment(track=False) 8 | 9 | update_dataset("janestreet2025-models", PATH_MODELS) 10 | upload_code("janestreet2025-code", PATH_CODE) 11 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Setup python package.""" 2 | 3 | from setuptools import setup, find_packages 4 | 5 | setup( 6 | name='janestreet', 7 | version='0.1', 8 | packages=find_packages(where='.'), 9 | package_dir={'': '.'}, 10 | exclude=[ 11 | "data", 12 | "notebooks", 13 | "models", 14 | "scripts", 15 | "*.notebooks", 16 | "data/*", 17 | "notebooks/*", 18 | "models/*", 19 | "scripts/*", 20 | "archive", 21 | "archive/*", 22 | "features", 23 | "features/*" 24 | ], 25 | ) 26 | -------------------------------------------------------------------------------- /janestreet/setup_env.py: -------------------------------------------------------------------------------- 1 | """Set up project environment.""" 2 | 3 | import os 4 | 5 | 6 | def setup_environment(track: bool = False): 7 | """Set up project environment. 8 | 9 | Args: 10 | track (bool, optional): Whether to set up Weights & Biases. Defaults to True. 
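    Example (illustrative; note that the function signature actually defaults to ``track=False``):
        setup_environment()            # load variables from .env only
        setup_environment(track=True)  # additionally log in to Weights & Biases via WANDB_TOKEN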
11 | """ 12 | # Dotenv 13 | print("Loading environment variables from .env file...") 14 | from dotenv import load_dotenv 15 | load_dotenv() 16 | 17 | # Set up Weights & Biases 18 | if track: 19 | print("Setting up Weights & Biases...") 20 | import wandb 21 | wandb.login(key=os.environ.get('WANDB_TOKEN')) 22 | 23 | print("Environment setup complete.") 24 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "janestreet" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Evgeniia "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = ">=3.10" 10 | polars = "1.9.0" 11 | scikit-learn = "^1.5.2" 12 | numpy = "1.26.4" 13 | pandas = "^2.2.3" 14 | matplotlib = "^3.9.2" 15 | tqdm = "^4.66.5" 16 | wandb = "^0.18.5" 17 | python-dotenv = "^1.0.1" 18 | kaggle = "^1.6.17" 19 | grpcio = "^1.67.0" 20 | pyarrow = "^17.0.0" 21 | pytz = "^2024.2" 22 | torch = "2.4.0" 23 | torch_optimizer = "0.3.0" 24 | 25 | [tool.poetry.group.dev.dependencies] 26 | ipykernel = "^6.29.5" 27 | 28 | [build-system] 29 | requires = ["poetry-core"] 30 | build-backend = "poetry.core.masonry.api" 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Kaggle Jane Street Real-Time Market Data Forecasting 2 | 3 | Solution for the [Jane Street 2024 Kaggle competition](https://www.kaggle.com/competitions/jane-street-real-time-market-data-forecasting/overview). 4 | 5 | A detailed description can be found in [`solution.md`](solution.md) and in [Kaggle discussion](https://www.kaggle.com/competitions/jane-street-real-time-market-data-forecasting/discussion/556542). 6 | 7 | ## Requirements 8 | 9 | - ~100GB RAM 10 | - 12GB GPU RAM 11 | 12 | ## Usage 13 | 14 | 1. Install requirements from [`pyproject.toml`](pyproject.toml). 15 | 2. Download the dataset from [Kaggle](https://www.kaggle.com/competitions/jane-street-real-time-market-data-forecasting/data). 16 | 3. Set paths and other config variables in [`janestreet/config.py`](janestreet/config.py). 17 | 18 | ### Scripts 19 | 20 | - [`run_cv.py`](scripts/python/run_cv.py) - Estimate model on cross-validation. 21 | - [`run_full.py`](scripts/python/run_full.py) - Estimate model for the final submission (on the whole sample). 22 | - [`run_ensemble.py`](scripts/python/run_ensemble.py) - Evaluate ensemble of models on CV. 23 | - [`run_test_gap.py`](scripts/python/run_test_gap.py) - Test model on a sample of the last 200 dates with a gap of 200 dates. 24 | 25 | #### Additional scripts 26 | 27 | - [`monitor_kaggle.py`](scripts/python/monitor_kaggle.py) - Monitor kaggle submissions and send notifications when completed. 28 | - [`update_kaggle.py`](scripts/python/update_kaggle.py) - Push code and models to Kaggle datasets to be used in submission. 
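The scripts above are thin wrappers around the `janestreet` package. Purely as an illustration (the model name below is arbitrary and the full hyper-parameter set lives in `run_cv.py`), a cross-validation run boils down to roughly:

```python
from janestreet.setup_env import setup_environment
from janestreet.data_processor import DataProcessor
from janestreet.transformers import PolarsTransformer
from janestreet.pipeline import FullPipeline, PipelineCV
from janestreet.models.nn import NN

setup_environment(track=False)

# Load training data with engineered features, skipping the early dates.
data_processor = DataProcessor("gru_example_cv", skip_days=700)
df = data_processor.get_train_data()
features = data_processor.features

# Time-series GRU evaluated on a 2-fold time-series CV.
model = NN(model_type="gru", random_seed=0)
pipeline = FullPipeline(
    model,
    preprocessor=PolarsTransformer(features),
    run_name="full",
    name="gru_example_cv",
    load_model=False,
    features=features,
    refit=True,
    change_lr=True,
)
cv = PipelineCV(pipeline, None, n_splits=2)
scores = cv.fit(df, verbose=True)
```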
29 | -------------------------------------------------------------------------------- /janestreet/config.py: -------------------------------------------------------------------------------- 1 | """Configuration file.""" 2 | 3 | import os 4 | from pathlib import Path 5 | 6 | # Determine the execution environment based on environment variables 7 | KAGGLE = 'KAGGLE_URL_BASE' in os.environ 8 | VASTAI = not KAGGLE 9 | 10 | # Define base paths for different environments (change if needed) 11 | # Data and models can be stored on a different volume 12 | # Path with data should contain subdirectoriy "data" 13 | # with data from https://www.kaggle.com/competitions/jane-street-real-time-market-data-forecasting/data 14 | base_paths = { 15 | "VASTAI": Path("/home/janestreet2024"), 16 | "VASTAI_DATA": Path("/workspace/kaggle/janestreet"), 17 | "KAGGLE": Path("/kaggle/input"), 18 | } 19 | 20 | # Set paths based on the environment (change if needed) 21 | if VASTAI: 22 | base_path = base_paths["VASTAI"] 23 | base_path_data = base_paths["VASTAI_DATA"] 24 | PATH_DATA = base_path_data / "data" 25 | PATH_MODELS = base_path_data / "models" 26 | PATH_CODE = base_path / "dist/janestreet-0.1-py3-none-any.whl" 27 | elif KAGGLE: 28 | base_path = base_paths["KAGGLE"] 29 | base_path_data = base_paths["KAGGLE"] 30 | PATH_DATA = base_path / "jane-street-real-time-market-data-forecasting" 31 | PATH_MODELS = base_path / "janestreet2025-models" 32 | PATH_CODE = base_path / "janestreet2025-code/janestreet-0.1-py3-none-any.whl" 33 | else: 34 | raise ValueError("Unknown environment") 35 | 36 | PATHS_DATA = { 37 | "train": PATH_DATA / "train", 38 | "test": PATH_DATA / "test", 39 | } 40 | 41 | # Set other configuration variables 42 | # Wandb (to track experiments, not required) 43 | WANDB_PROJECT = "kaggle_janestreet" 44 | 45 | # Kaggle (to push code and models, not required) 46 | KAGGLE_USERNAME = "eivolkova" 47 | 48 | # Default random seed 49 | RANDOM_SEED = 42 50 | 51 | # Data column names (do not change) 52 | COL_TARGET = "responder_6" 53 | COL_ID = "symbol_id" 54 | COL_DATE = "date_id" 55 | COL_TIME = "time_id" 56 | COL_WEIGHT = "weight" 57 | COLS_RESPONDERS = [f"responder_{i}" for i in range(11)] 58 | -------------------------------------------------------------------------------- /scripts/python/run_ensemble.py: -------------------------------------------------------------------------------- 1 | """Run ensemble. 
2 | """ 3 | import numpy as np 4 | 5 | from janestreet.setup_env import setup_environment 6 | from janestreet.pipeline import FullPipeline, PipelineCV, PipelineEnsemble 7 | from janestreet.models.nn import NN 8 | from janestreet.data_processor import DataProcessor 9 | from janestreet.tracker import WandbTracker 10 | 11 | TRACK = False 12 | 13 | MODEL_NAME = "ensemble" 14 | COMMENT = "" 15 | CATEGORY = "model_ver27_ts" 16 | N_SPLITS = 2 17 | START = 1000 18 | MODEL_NAMES = [ 19 | "gru_2.0_700_cv", 20 | "gru_2.1_700_cv", 21 | "gru_2.2_700_cv", 22 | "gru_3.0_700_cv", 23 | "gru_3.1_700_cv", 24 | "gru_3.2_700_cv", 25 | ] 26 | WEIGHTS = np.array([1.0]*len(MODEL_NAMES))/ len(MODEL_NAMES) 27 | REFIT_MODELS = [True] * len(MODEL_NAMES) 28 | 29 | setup_environment(TRACK) 30 | 31 | data_processor = DataProcessor(MODEL_NAME, skip_days=START) 32 | df = data_processor.get_train_data() 33 | 34 | print(MODEL_NAMES) 35 | print(WEIGHTS) 36 | print(REFIT_MODELS) 37 | 38 | models = [] 39 | for i, model_name in enumerate(MODEL_NAMES): 40 | pipeline = FullPipeline( 41 | NN(), 42 | name=model_name, 43 | run_name="full", 44 | load_model=True, 45 | features=None, 46 | refit=True, 47 | change_lr=False, 48 | ) 49 | models.append(pipeline) 50 | pipeline = PipelineEnsemble(models, WEIGHTS, REFIT_MODELS) 51 | 52 | if TRACK: 53 | params = {} 54 | params["n_splits"] = N_SPLITS 55 | params["model_type"] = "ensemble" 56 | params["models"] = MODEL_NAMES 57 | params["weights"] = WEIGHTS 58 | params["n_models"] = len(MODEL_NAMES) 59 | wandb_tracker = WandbTracker( 60 | MODEL_NAME, 61 | params, 62 | category=CATEGORY, 63 | comment=COMMENT 64 | ) 65 | wandb_tracker.init_run([]) 66 | else: 67 | wandb_tracker = None 68 | 69 | cv = PipelineCV(pipeline, wandb_tracker, n_splits=N_SPLITS) 70 | scores = cv.fit(df, verbose=True) 71 | 72 | if wandb_tracker is not None: 73 | wandb_tracker.finish() 74 | -------------------------------------------------------------------------------- /janestreet/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions. 2 | 3 | This module provides utility functions for executing shell commands 4 | and creating folders. 5 | 6 | Functions: 7 | run_shell_command: Executes a shell command and prints the output in real-time. 8 | create_folder: Creates a folder, with an option to remove existing ones. 9 | """ 10 | 11 | import os 12 | import shutil 13 | import subprocess 14 | 15 | 16 | def run_shell_command(command: str, cwd: str = None) -> None: 17 | """Executes a shell command and prints the output/errors in real-time. 18 | 19 | Args: 20 | command (str): The shell command to execute. 21 | cwd (str, optional): The working directory where the command should be executed. 22 | Defaults to None. 
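    Example (illustrative; mirrors the call made in ``janestreet/kaggle.py``, the path is a placeholder):
        run_shell_command("python setup.py sdist bdist_wheel", cwd="/path/to/repo")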
23 | """ 24 | try: 25 | # Print the command itself 26 | print(f"Running command: {command}") 27 | 28 | # Start the process 29 | process = subprocess.Popen( 30 | command, 31 | cwd=cwd, 32 | stdout=subprocess.PIPE, 33 | stderr=subprocess.PIPE, 34 | shell=True, 35 | universal_newlines=True 36 | ) 37 | 38 | # Print the output in real-time 39 | while True: 40 | output = process.stdout.readline() 41 | if output == '' and process.poll() is not None: 42 | break 43 | if output: 44 | print(output.strip()) 45 | 46 | # Capture any remaining errors 47 | stderr = process.stderr.read() 48 | if stderr: 49 | print("Errors:\n", stderr) 50 | 51 | except Exception as e: 52 | print(f"An error occurred: {e}") 53 | 54 | 55 | def create_folder(path: str, rm: bool = False) -> None: 56 | """Creates a folder. 57 | 58 | Args: 59 | path (str): Path to the folder. 60 | rm (bool, optional): Whether to remove the folder if it already exists. Defaults to False. 61 | """ 62 | if rm: 63 | if os.path.exists(path): 64 | shutil.rmtree(path) 65 | os.makedirs(path, exist_ok=True) 66 | -------------------------------------------------------------------------------- /scripts/python/run_test_gap.py: -------------------------------------------------------------------------------- 1 | """Load model and test it on a sample of last 200 days with a gap of 200 days. 2 | """ 3 | import numpy as np 4 | import polars as pl 5 | from tqdm.auto import tqdm 6 | 7 | from janestreet.setup_env import setup_environment 8 | from janestreet.pipeline import FullPipeline 9 | from janestreet.models.nn import NN 10 | from janestreet.config import COL_DATE, COL_WEIGHT, COL_TARGET 11 | from janestreet.data_processor import DataProcessor 12 | from janestreet.metrics import r2_weighted 13 | from janestreet.transformers import PolarsTransformer 14 | 15 | 16 | MODEL_TYPE = "gru" 17 | NUM = 3 18 | 19 | setup_environment() 20 | 21 | data_processor = DataProcessor(f"{MODEL_TYPE}_{NUM}.{0}_{700}_cv", skip_days=1200) 22 | df = data_processor.get_train_data() 23 | 24 | features = data_processor.features 25 | 26 | print(features) 27 | print(f"Number of features: {len(features)}") 28 | 29 | preds = [] 30 | for NUM in [2, 3]: 31 | for SEED in range(3): 32 | MODEL_NAME = f"{MODEL_TYPE}_{NUM}.{SEED}_{700}_cv" 33 | print(MODEL_NAME) 34 | 35 | model = NN(random_seed=SEED) 36 | 37 | pipeline = FullPipeline( 38 | model, 39 | preprocessor=PolarsTransformer(features), 40 | run_name="fold0", 41 | name=MODEL_NAME, 42 | load_model=True, 43 | features=features, 44 | refit=True, 45 | change_lr=False, 46 | ) 47 | df_test = df.filter((pl.col(COL_DATE) >= 1499)&(pl.col(COL_DATE) < 1699)) 48 | df_valid = df.filter(pl.col(COL_DATE) >= 1299) 49 | 50 | pipeline.fit(verbose=True) 51 | 52 | cnt_dates = 0 53 | preds_m = [] 54 | dates = np.unique(df_valid.select(pl.col(COL_DATE)).to_series().to_numpy()) 55 | for date_id in tqdm(dates): 56 | df_valid_date = df.filter(pl.col(COL_DATE) == date_id) 57 | if pipeline.refit & (cnt_dates > 0): 58 | df_valid_time = df.filter(pl.col(COL_DATE) == date_id-1) 59 | pipeline.update(df_valid_time) 60 | if date_id >= 1499: 61 | preds_i, hidden = pipeline.predict(df_valid_date, n_times=None) 62 | preds_m += list(preds_i) 63 | cnt_dates += 1 64 | 65 | preds_m = np.array(preds_m) 66 | preds.append(preds_m) 67 | 68 | preds = np.mean(preds, axis=0) 69 | y = df_test.select(pl.col(COL_TARGET)).to_series().to_numpy() 70 | weight = df_test.select(pl.col(COL_WEIGHT)).to_series().to_numpy() 71 | score = r2_weighted(y, preds, weight) 72 | print(f"Score: {score:.5f}") 73 
| -------------------------------------------------------------------------------- /scripts/python/run_cv.py: -------------------------------------------------------------------------------- 1 | """Run model on CV. 2 | """ 3 | 4 | from janestreet.setup_env import setup_environment 5 | from janestreet.pipeline import FullPipeline, PipelineCV 6 | from janestreet.models.nn import NN 7 | from janestreet.data_processor import DataProcessor 8 | from janestreet.tracker import WandbTracker 9 | from janestreet.transformers import PolarsTransformer 10 | 11 | TRACK = False 12 | COMMENT = "" 13 | CATEGORY = "model_ver27_ts" 14 | 15 | MODEL_TYPE = "gru" 16 | NUM = 3 17 | LOAD_MODEL = False 18 | REFIT = True 19 | N_SPLITS = 2 20 | START = 700 21 | TRAIN_SIZE = None 22 | 23 | setup_environment(TRACK) 24 | 25 | data_processor = DataProcessor(f"{MODEL_TYPE}_{NUM}.{0}_{START}_cv", skip_days=START) 26 | df = data_processor.get_train_data() 27 | 28 | features = data_processor.features 29 | 30 | params_nn = { 31 | "model_type": "gru", 32 | 33 | # ### Model 1 34 | # "hidden_sizes": [250, 150, 150], 35 | # "dropout_rates": [0.0, 0.0, 0.0], 36 | # "hidden_sizes_linear": [], 37 | # "dropout_rates_linear": [], 38 | # ### 39 | 40 | ### Model 2 41 | "hidden_sizes": [500], 42 | "dropout_rates": [0.3, 0.0, 0.0], 43 | "hidden_sizes_linear": [500, 300], 44 | "dropout_rates_linear": [0.2, 0.1], 45 | ### 46 | 47 | "batch_size": 1, 48 | "early_stopping_patience": 1, 49 | "lr_refit": 0.0003, 50 | "lr": 0.0005, 51 | 52 | "epochs": 1000, 53 | "early_stopping": True, 54 | "lr_patience": 1000, 55 | "lr_factor": 0.5, 56 | } 57 | 58 | print(features) 59 | print(f"Number of features: {len(features)}") 60 | print(params_nn) 61 | 62 | for SEED in range(3): 63 | MODEL_NAME = f"{MODEL_TYPE}_{NUM}.{SEED}_{700}_cv" 64 | 65 | model = NN(**params_nn, random_seed=SEED) 66 | 67 | pipeline = FullPipeline( 68 | model, 69 | preprocessor=PolarsTransformer(features), 70 | run_name="full", 71 | name=MODEL_NAME, 72 | load_model=LOAD_MODEL, 73 | features=features, 74 | refit=REFIT, 75 | change_lr=True, 76 | ) 77 | 78 | wandb_tracker = None 79 | if TRACK: 80 | params = dict(params_nn) 81 | params["n_splits"] = N_SPLITS 82 | params["seed"] = SEED 83 | params["start"] = START 84 | wandb_tracker = WandbTracker( 85 | MODEL_NAME, 86 | params, 87 | category=CATEGORY, 88 | comment=COMMENT 89 | ) 90 | wandb_tracker.init_run(features) 91 | 92 | cv = PipelineCV(pipeline, wandb_tracker, n_splits=N_SPLITS, train_size=TRAIN_SIZE) 93 | scores = cv.fit(df, verbose=True) 94 | 95 | if wandb_tracker is not None: 96 | wandb_tracker.finish() 97 | -------------------------------------------------------------------------------- /janestreet/metrics.py: -------------------------------------------------------------------------------- 1 | """Weighted R² functions.""" 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | def r2_weighted( 9 | y_true: np.array, 10 | y_pred: np.array, 11 | sample_weight: np.array 12 | ) -> float: 13 | """Compute the weighted R² score. 14 | 15 | Args: 16 | y_true (np.array): Ground truth values. 17 | y_pred (np.array): Predicted values. 18 | sample_weight (np.array): Weights for each observation. 19 | 20 | Returns: 21 | float: Weighted R² score. 
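    Example (worked illustration with equal weights):
        With y_true = [1.0, -1.0], y_pred = [0.5, -0.5] and sample_weight = [1.0, 1.0],
        the weighted mean squared error is 0.25 and the weighted mean of y_true**2 is 1.0,
        so the function returns 1 - 0.25 / 1.0 = 0.75.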
22 | """ 23 | r2 = 1 - np.average((y_pred - y_true) ** 2, weights=sample_weight) / ( 24 | np.average((y_true) ** 2, weights=sample_weight) + 1e-38 25 | ) 26 | return r2 27 | 28 | def r2_weighted_torch( 29 | y_true: torch.Tensor, 30 | y_pred: torch.Tensor, 31 | sample_weight: torch.Tensor 32 | ) -> torch.Tensor: 33 | """Compute the weighted R² score using PyTorch tensors. 34 | 35 | Args: 36 | y_true (torch.Tensor): Ground truth tensor. 37 | y_pred (torch.Tensor): Predicted tensor. 38 | sample_weight (torch.Tensor): Weights for each observation (same shape as y_true). 39 | 40 | Returns: 41 | torch.Tensor: Weighted R² score. 42 | """ 43 | numerator = torch.sum(sample_weight * (y_pred - y_true) ** 2) 44 | denominator = torch.sum(sample_weight * (y_true) ** 2) + 1e-38 45 | r2 = 1 - (numerator / denominator) 46 | return r2 47 | 48 | class WeightedR2Loss(nn.Module): 49 | """PyTorch loss function for weighted R².""" 50 | def __init__(self, epsilon: float = 1e-38) -> None: 51 | """ 52 | Initialize the WeightedR2Loss class. 53 | 54 | Args: 55 | epsilon (float, optional): Small constant added to the denominator 56 | for numerical stability. Defaults to 1e-38. 57 | """ 58 | super(WeightedR2Loss, self).__init__() 59 | self.epsilon = epsilon 60 | 61 | def forward( 62 | self, 63 | y_pred: torch.Tensor, 64 | y_true: torch.Tensor, 65 | weights: torch.Tensor 66 | ) -> torch.Tensor: 67 | """Compute the weighted R² loss. 68 | 69 | Args: 70 | y_true (torch.Tensor): Ground truth tensor. 71 | y_pred (torch.Tensor): Predicted tensor. 72 | weights (torch.Tensor): Weights for each observation (same shape as y_true). 73 | 74 | Returns: 75 | torch.Tensor: Computed weighted R² loss. 76 | """ 77 | numerator = torch.sum(weights * (y_pred - y_true) ** 2) 78 | denominator = torch.sum(weights * (y_true) ** 2) + 1e-38 79 | loss = numerator / denominator 80 | return loss 81 | -------------------------------------------------------------------------------- /scripts/python/run_full_3.py: -------------------------------------------------------------------------------- 1 | """Run model for final submission. 
2 | """ 3 | 4 | import numpy as np 5 | import polars as pl 6 | from tqdm.auto import tqdm 7 | 8 | from janestreet.setup_env import setup_environment 9 | from janestreet.pipeline import FullPipeline 10 | from janestreet.models.nn import NN 11 | from janestreet.config import COL_DATE, COL_WEIGHT, COL_TARGET 12 | from janestreet.data_processor import DataProcessor 13 | from janestreet.metrics import r2_weighted 14 | from janestreet.transformers import PolarsTransformer 15 | 16 | 17 | MODEL_TYPE = "gru" 18 | START = 500 19 | NUM = 3 20 | EPOCHS_LS = [8]*10 21 | COMMENT = "" 22 | CATEGORY = "model_ver22_ts" 23 | 24 | LOAD_MODEL = False 25 | REFIT = True 26 | 27 | setup_environment() 28 | 29 | data_processor = DataProcessor( 30 | f"{MODEL_TYPE}_{NUM}.0_700", 31 | skip_days=START, 32 | transformer=PolarsTransformer() 33 | ) 34 | df = data_processor.get_train_data() 35 | features = data_processor.features 36 | 37 | params_nn = { 38 | "model_type": "gru", 39 | 40 | ### Model 1 41 | "hidden_sizes": [250, 150, 150], 42 | "dropout_rates": [0.0, 0.0, 0.0], 43 | "hidden_sizes_linear": [], 44 | "dropout_rates_linear": [], 45 | ### 46 | 47 | ### Model 2 48 | # "hidden_sizes": [500], 49 | # "dropout_rates": [0.3, 0.0, 0.0], 50 | # "hidden_sizes_linear": [500, 300], 51 | # "dropout_rates_linear": [0.2, 0.1], 52 | ### 53 | 54 | "batch_size": 1, 55 | "early_stopping_patience": 1, 56 | "lr_refit": 0.0003, 57 | "lr": 0.0005, 58 | 59 | "epochs": 100, 60 | "early_stopping": True, 61 | "lr_patience": 10, 62 | "lr_factor": 0.5, 63 | } 64 | 65 | print(features) 66 | print(f"Number of features: {len(features)}") 67 | print(params_nn) 68 | 69 | for SEED in range(3): 70 | MODEL_NAME = f"{MODEL_TYPE}_{NUM}.{SEED}_700" 71 | EPOCHS = EPOCHS_LS[SEED] 72 | 73 | print(MODEL_NAME) 74 | 75 | if EPOCHS is not None: 76 | params_nn["early_stopping"] = False 77 | params_nn["epochs"] = EPOCHS 78 | print(f"Running final model with {params_nn['epochs']} epochs.") 79 | model = NN(**params_nn, random_seed=SEED) 80 | 81 | df_train = df.filter(pl.col(COL_DATE) >= START+200) 82 | df_valid = df.filter(pl.col(COL_DATE) < START+200) 83 | 84 | pipeline = FullPipeline( 85 | model, 86 | preprocessor=None, 87 | run_name="full", 88 | name=MODEL_NAME, 89 | load_model=LOAD_MODEL, 90 | features=features, 91 | refit=REFIT, 92 | ) 93 | pipeline.fit(df_train, df_valid, verbose=True) 94 | 95 | cnt_dates = 0 96 | preds = [] 97 | dates = np.unique(df_valid.select(pl.col(COL_DATE)).to_series().to_numpy()) 98 | for date_id in tqdm(dates): 99 | df_valid_date = df.filter(pl.col(COL_DATE) == date_id) 100 | if pipeline.refit & (cnt_dates > 0): 101 | df_valid_upd = df.filter(pl.col(COL_DATE) == date_id-1) 102 | pipeline.update(df_valid_upd) 103 | preds_i, hidden = pipeline.predict(df_valid_date, n_times=None) 104 | preds += list(preds_i) 105 | cnt_dates += 1 106 | 107 | preds = np.array(preds) 108 | y = df_valid.select(pl.col(COL_TARGET)).to_series().to_numpy() 109 | weight = df_valid.select(pl.col(COL_WEIGHT)).to_series().to_numpy() 110 | score = r2_weighted(y, preds, weight) 111 | print(f"Score: {score:.5f}") 112 | -------------------------------------------------------------------------------- /scripts/python/run_full_2.py: -------------------------------------------------------------------------------- 1 | """Run model for final submission. 
2 | """ 3 | 4 | import numpy as np 5 | import polars as pl 6 | from tqdm.auto import tqdm 7 | 8 | from janestreet.setup_env import setup_environment 9 | from janestreet.pipeline import FullPipeline 10 | from janestreet.models.nn import NN 11 | from janestreet.config import COL_DATE, COL_WEIGHT, COL_TARGET 12 | from janestreet.data_processor import DataProcessor 13 | from janestreet.metrics import r2_weighted 14 | from janestreet.transformers import PolarsTransformer 15 | 16 | 17 | MODEL_TYPE = "gru" 18 | START = 500 19 | NUM = 2 20 | EPOCHS_LS = [8]*10 21 | COMMENT = "" 22 | CATEGORY = "model_ver22_ts" 23 | 24 | LOAD_MODEL = False 25 | REFIT = True 26 | 27 | setup_environment() 28 | 29 | data_processor = DataProcessor( 30 | f"{MODEL_TYPE}_{NUM}.0_700", 31 | skip_days=START, 32 | transformer=PolarsTransformer() 33 | ) 34 | df = data_processor.get_train_data() 35 | features = data_processor.features 36 | 37 | params_nn = { 38 | "model_type": "gru", 39 | 40 | # ### Model 1 41 | # "hidden_sizes": [250, 150, 150], 42 | # "dropout_rates": [0.0, 0.0, 0.0], 43 | # "hidden_sizes_linear": [], 44 | # "dropout_rates_linear": [], 45 | # ### 46 | 47 | ### Model 2 48 | "hidden_sizes": [500], 49 | "dropout_rates": [0.3, 0.0, 0.0], 50 | "hidden_sizes_linear": [500, 300], 51 | "dropout_rates_linear": [0.2, 0.1], 52 | ### 53 | 54 | "batch_size": 1, 55 | "early_stopping_patience": 1, 56 | "lr_refit": 0.0003, 57 | "lr": 0.0005, 58 | 59 | "epochs": 100, 60 | "early_stopping": True, 61 | "lr_patience": 10, 62 | "lr_factor": 0.5, 63 | } 64 | 65 | print(features) 66 | print(f"Number of features: {len(features)}") 67 | print(params_nn) 68 | 69 | for SEED in range(3): 70 | MODEL_NAME = f"{MODEL_TYPE}_{NUM}.{SEED}_700" 71 | EPOCHS = EPOCHS_LS[SEED] 72 | 73 | print(MODEL_NAME) 74 | 75 | if EPOCHS is not None: 76 | params_nn["early_stopping"] = False 77 | params_nn["epochs"] = EPOCHS 78 | print(f"Running final model with {params_nn['epochs']} epochs.") 79 | model = NN(**params_nn, random_seed=SEED) 80 | 81 | df_train = df.filter(pl.col(COL_DATE) >= START+200) 82 | df_valid = df.filter(pl.col(COL_DATE) < START+200) 83 | 84 | pipeline = FullPipeline( 85 | model, 86 | preprocessor=None, 87 | run_name="full", 88 | name=MODEL_NAME, 89 | load_model=LOAD_MODEL, 90 | features=features, 91 | refit=REFIT, 92 | ) 93 | pipeline.fit(df_train, df_valid, verbose=True) 94 | 95 | cnt_dates = 0 96 | preds = [] 97 | dates = np.unique(df_valid.select(pl.col(COL_DATE)).to_series().to_numpy()) 98 | for date_id in tqdm(dates): 99 | df_valid_date = df.filter(pl.col(COL_DATE) == date_id) 100 | if pipeline.refit & (cnt_dates > 0): 101 | df_valid_upd = df.filter(pl.col(COL_DATE) == date_id-1) 102 | pipeline.update(df_valid_upd) 103 | preds_i, hidden = pipeline.predict(df_valid_date, n_times=None) 104 | preds += list(preds_i) 105 | cnt_dates += 1 106 | 107 | preds = np.array(preds) 108 | y = df_valid.select(pl.col(COL_TARGET)).to_series().to_numpy() 109 | weight = df_valid.select(pl.col(COL_WEIGHT)).to_series().to_numpy() 110 | score = r2_weighted(y, preds, weight) 111 | print(f"Score: {score:.5f}") 112 | -------------------------------------------------------------------------------- /janestreet/kaggle.py: -------------------------------------------------------------------------------- 1 | """Kaggle datasets utility functions. 2 | 3 | This module provides utility functions for updating and uploading Kaggle datasets. 4 | 5 | Functions: 6 | update_dataset: Updates a Kaggle dataset with specific subfolders. 
7 | upload_code: Uploads a Python package to a Kaggle dataset. 8 | """ 9 | 10 | import json 11 | import os 12 | import shutil 13 | import tempfile 14 | from pathlib import Path 15 | 16 | from .config import KAGGLE_USERNAME, base_path 17 | from .utils import run_shell_command 18 | 19 | 20 | def update_dataset(dataset_id: str, source_path: str) -> None: 21 | """Updates a Kaggle dataset with specific subfolders. 22 | 23 | Args: 24 | dataset_id (str): The ID of the Kaggle dataset (e.g., "username/dataset-name"). 25 | source_path (str): The local path where the dataset files are stored. 26 | """ 27 | original_source_path = Path(source_path) 28 | temp_dir = Path("/tmp/kaggle_dataset_temp/models") 29 | 30 | # Create a temporary directory for uploading 31 | if temp_dir.exists(): 32 | shutil.rmtree(temp_dir) # Clean up if it already exists 33 | temp_dir.mkdir(parents=True) 34 | 35 | # Copy only the desired subfolders to the temporary directory 36 | for folder_name in ["full", "data_processors"]: 37 | src = original_source_path / folder_name 38 | dest = temp_dir / folder_name # Flatten structure if necessary 39 | if src.exists(): 40 | shutil.copytree(src, dest) 41 | else: 42 | print(f"Warning: Folder '{src}' does not exist and won't be uploaded.") 43 | 44 | # Debug: List files in temp_dir to ensure they exist 45 | print(f"Files in temp_dir ({temp_dir}):") 46 | for root, _, files in os.walk(temp_dir): 47 | for file in files: 48 | print(os.path.join(root, file)) 49 | 50 | # Add metadata 51 | metadata = { 52 | "id": f"{KAGGLE_USERNAME}/{dataset_id}", 53 | } 54 | metadata_path = temp_dir / 'dataset-metadata.json' 55 | with open(metadata_path, 'w', encoding='utf-8') as f: 56 | json.dump(metadata, f) 57 | 58 | # Update the Kaggle dataset 59 | command = f""" 60 | kaggle datasets version -p '{temp_dir}' -m 'Updated dataset' -r zip 61 | """ 62 | print(f"Running command: {command}") 63 | os.system(command) 64 | 65 | # Clean up temporary directory 66 | shutil.rmtree(temp_dir) 67 | 68 | 69 | def upload_code(dataset_id: str, source_path: str) -> None: 70 | """Uploads a Python package to a Kaggle dataset. 71 | 72 | Args: 73 | dataset_id (str): The ID of the Kaggle dataset (e.g., "username/dataset-name"). 74 | source_path (str): The local path where the Python package is stored. 75 | 76 | Notes: 77 | This function builds a Python package in the current working directory 78 | using the `setup.py` file, and then uploads it to the specified Kaggle 79 | dataset. 
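    Example (illustrative; this is the call used in ``scripts/python/update_kaggle.py``):
        upload_code("janestreet2025-code", PATH_CODE)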
80 | """ 81 | os.chdir(base_path) 82 | run_shell_command("python setup.py sdist bdist_wheel") 83 | 84 | original_source_path = Path(source_path) 85 | temp_dir = tempfile.mkdtemp() 86 | shutil.copy(original_source_path, temp_dir) 87 | source_path = Path(temp_dir) 88 | 89 | metadata = { 90 | "id": f"{KAGGLE_USERNAME}/{dataset_id}", 91 | } 92 | 93 | metadata_path = source_path / 'dataset-metadata.json' 94 | with open(metadata_path, 'w', encoding='utf-8') as f: 95 | json.dump(metadata, f) 96 | 97 | run_shell_command( 98 | f""" 99 | kaggle datasets version -p '{source_path}' -m 'Updated dataset' 100 | """ 101 | ) 102 | 103 | os.remove(metadata_path) 104 | shutil.rmtree(temp_dir) 105 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 
109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # Tmp 163 | tmp 164 | junk 165 | analysis 166 | 167 | # Data 168 | data/* 169 | archive 170 | models/* 171 | submission.csv 172 | analysis 173 | 174 | # Code 175 | notebooks/junk.ipynb 176 | 177 | # Poetry 178 | poetry.lock 179 | 180 | # Mac 181 | .DS_Store 182 | 183 | # Env 184 | kaggle.json 185 | 186 | # Config 187 | pyrightconfig.json 188 | 189 | # Linter 190 | .pylintrc 191 | 192 | # Wandb 193 | wandb 194 | 195 | # Bash 196 | scripts/bash/ 197 | 198 | # Other 199 | info 200 | solution_sbm.md -------------------------------------------------------------------------------- /janestreet/tracker.py: -------------------------------------------------------------------------------- 1 | """Evaluation and logging functions.""" 2 | 3 | import pandas as pd 4 | import wandb 5 | 6 | from .config import WANDB_PROJECT, base_path 7 | 8 | class WandbTracker: 9 | """Custom class for tracking experiments using WandB. 10 | 11 | This class provides methods for initializing runs, saving features and data, logging metrics, 12 | sending alerts, and updating run summaries and settings. 13 | 14 | Args: 15 | run_name (str): Name of the run. 16 | params (dict): Dictionary containing parameters for the run. 17 | category (str): Category of the run. 18 | comment (str): Comment or description for the run. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | run_name: str, 24 | params: dict, 25 | category: str, 26 | comment: str 27 | ) -> None: 28 | """Initializes the WandbTracker class. 29 | 30 | Args: 31 | run_name (str): Name of the run. 32 | params (dict): Dictionary containing parameters for the run. 33 | category (str): Category of the run. 34 | comment (str): Comment or description for the run. 35 | """ 36 | self.run_name = run_name 37 | self.params = params 38 | self.category = category 39 | self.comment = comment 40 | self.api = wandb.Api() 41 | 42 | def init_run(self, features: list) -> None: 43 | """Initializes a new WandB run. 44 | 45 | Args: 46 | features (list): List of features used in the model. 
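    Example (illustrative; mirrors the usage in ``scripts/python/run_cv.py``):
        tracker = WandbTracker("gru_3.0_700_cv", params, category="model_ver27_ts", comment="")
        tracker.init_run(features)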
47 | """ 48 | config = self.params.copy() 49 | config.update({ 50 | "model": "lgb", 51 | "category": self.category, 52 | "comment": self.comment, 53 | "n_features": len(features) 54 | }) 55 | wandb.init( 56 | project=WANDB_PROJECT, 57 | name=self.run_name, 58 | config=config, 59 | dir=base_path, 60 | save_code=True 61 | ) 62 | self.save_features(features) 63 | print(f"Running {self.run_name} model.") 64 | print(self.comment) 65 | 66 | def save_features(self, features: list) -> None: 67 | """Saves the list of features as a WandB artifact. 68 | 69 | Args: 70 | features (list): List of features used in the model. 71 | """ 72 | feature_file_path = "features.txt" 73 | with open(feature_file_path, "w", encoding="utf-8") as file: 74 | for feature in features: 75 | file.write(f"{feature}\n") 76 | artifact = wandb.Artifact( 77 | name=f"{self.run_name}-feature-list", 78 | type="dataset" 79 | ) 80 | artifact.add_file(feature_file_path) 81 | wandb.log_artifact(artifact) 82 | 83 | def save_data(self, df: pd.DataFrame, name: str) -> None: 84 | """Saves a DataFrame as a WandB artifact. 85 | 86 | Args: 87 | df (pd.DataFrame): DataFrame to be saved. 88 | name (str): Name of the artifact. 89 | """ 90 | tab = wandb.Table(columns=list(df.columns), data=df.values.tolist()) 91 | wandb.log({name: tab}) 92 | 93 | def alert(self, text: str) -> None: 94 | """Sends an alert to the user via WandB. 95 | 96 | Args: 97 | text (str): Text of the alert. 98 | """ 99 | wandb.alert( 100 | title=f'Run {self.run_name} finished.', 101 | text=text, 102 | level=wandb.AlertLevel.INFO 103 | ) 104 | 105 | def log_metrics(self, metrics: dict) -> None: 106 | """Logs metrics to the current WandB run. 107 | 108 | Args: 109 | metrics (dict): Dictionary containing the metrics to log. 110 | """ 111 | wandb.log(metrics) 112 | 113 | def update_summary(self, run_id: str, summary_params: dict) -> None: 114 | """Updates the summary of an existing WandB run. 115 | 116 | Args: 117 | run_id (str): ID of the WandB run. 118 | summary_params (dict): Dictionary containing summary parameters to update. 119 | """ 120 | run = self.api.run(f"eivolkova3/kaggle_home_credit/{run_id}") 121 | for key, val in summary_params.items(): 122 | run.summary[key] = val 123 | run.summary.update() 124 | 125 | def update_settings(self, run_id: str, settings_params: dict) -> None: 126 | """Updates the settings of an existing WandB run. 127 | 128 | Args: 129 | run_id (str): ID of the WandB run. 130 | settings_params (dict): Dictionary containing settings parameters to update. 131 | """ 132 | run = self.api.run(f"eivolkova3/kaggle_home_credit/{run_id}") 133 | for key, val in settings_params.items(): 134 | run.settings[key] = val 135 | run.update() 136 | 137 | def finish(self) -> None: 138 | """Finishes the current WandB run. 139 | """ 140 | wandb.finish() 141 | -------------------------------------------------------------------------------- /janestreet/transformers.py: -------------------------------------------------------------------------------- 1 | """Custom transformer for preprocessing data using Polars.""" 2 | 3 | import polars as pl 4 | 5 | class PolarsTransformer: 6 | """A custom transformer for preprocessing data using Polars. 7 | 8 | This transformer provides functionality to scale, fill missing values, 9 | and clip features in a DataFrame. It can fit on a given 10 | dataset to compute statistics and apply transformations accordingly. 11 | 12 | Args: 13 | features (list, optional): List of feature columns to be transformed. Defaults to None. 
14 | fillnull (bool, optional): Whether to fill null values with 0. Defaults to True. 15 | scale (bool, optional): Whether to scale the features 16 | by mean and standard deviation. Defaults to True. 17 | clip_time (bool, optional): Whether to clip the "feature_time_id" column 18 | to its min and max values. Defaults to True. 19 | """ 20 | def __init__( 21 | self, 22 | features: list = None, 23 | fillnull: bool = True, 24 | scale: bool = True, 25 | clip_time: bool = True 26 | ) -> None: 27 | """Initializes the PolarsTransformer class. 28 | 29 | Args: 30 | features (list, optional): List of feature columns to be transformed. Defaults to None. 31 | fillnull (bool, optional): Whether to fill null values with 0. Defaults to True. 32 | scale (bool, optional): Whether to scale the features 33 | by mean and standard deviation. Defaults to True. 34 | clip_time (bool, optional): Whether to clip the "feature_time_id" column 35 | to its min and max values. Defaults to True. 36 | """ 37 | self.features = features 38 | self.fillnull = fillnull 39 | self.scale = scale 40 | self.clip_time = clip_time 41 | self.statistics_mean_std = None 42 | self.statistics_min_max = None 43 | 44 | def set_features(self, features: list) -> None: 45 | """Sets the list of features to be transformed. 46 | 47 | Args: 48 | features (list): List of feature columns to be transformed. 49 | """ 50 | self.features = features 51 | 52 | def fit_transform(self, df: pl.DataFrame) -> pl.DataFrame: 53 | """Fits the transformer on the given DataFrame and applies the transformations. 54 | 55 | Args: 56 | df (pl.DataFrame): The input Polars DataFrame. 57 | 58 | Returns: 59 | pl.DataFrame: The transformed Polars DataFrame. 60 | """ 61 | if self.scale: 62 | self.statistics_mean_std = { 63 | column: { 64 | "mean": df[column].mean(), 65 | "std": df[column].std(), 66 | } 67 | for column in self.features 68 | } 69 | 70 | if self.clip_time: 71 | self.statistics_min_max = { 72 | column: { 73 | "min": df[column].min(), 74 | "max": df[column].max(), 75 | } 76 | for column in ["feature_time_id"] 77 | } 78 | 79 | if self.fillnull: 80 | df = df.with_columns([ 81 | pl.col(column).fill_null(0.0) 82 | for column in self.features 83 | ]) 84 | 85 | if self.scale: 86 | df = df.with_columns([ 87 | ((pl.col(column) - self.statistics_mean_std[column]["mean"]) / 88 | self.statistics_mean_std[column]["std"]) 89 | for column in self.features 90 | ]) 91 | 92 | return df 93 | 94 | def transform(self, df: pl.DataFrame, refit: bool = False) -> pl.DataFrame: 95 | """Applies the transformations to the given DataFrame using precomputed statistics. 96 | 97 | Args: 98 | df (pl.DataFrame): The input Polars DataFrame. 99 | refit (bool, optional): If True, updates the min and max values for 100 | the "feature_time_id" column. Defaults to False. 101 | 102 | Returns: 103 | pl.DataFrame: The transformed Polars DataFrame. 
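    Example (illustrative; this mirrors the test-time path in ``DataProcessor.process_test_data``):
        df = transformer.transform(df, refit=True)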
104 | """ 105 | if refit: 106 | if self.clip_time: 107 | self.statistics_min_max.update({ 108 | column: { 109 | "min": ( 110 | self.statistics_min_max[column]["min"] 111 | if df[column].min() is None 112 | else min(df[column].min(), self.statistics_min_max[column]["min"]) 113 | ), 114 | "max": ( 115 | self.statistics_min_max[column]["max"] 116 | if df[column].max() is None 117 | else max(df[column].max(), self.statistics_min_max[column]["max"]) 118 | ), 119 | } 120 | for column in ["feature_time_id"] 121 | }) 122 | 123 | if self.clip_time: 124 | df = df.with_columns([ 125 | pl.col(column).clip( 126 | self.statistics_min_max[column]["min"], 127 | self.statistics_min_max[column]["max"] 128 | ) 129 | for column in ["feature_time_id"] 130 | ]) 131 | 132 | if self.fillnull: 133 | df = df.with_columns([ 134 | pl.col(column).fill_null(0.0) 135 | for column in self.features 136 | ]) 137 | 138 | if self.scale: 139 | df = df.with_columns([ 140 | ((pl.col(column) - self.statistics_mean_std[column]["mean"]) / 141 | self.statistics_mean_std[column]["std"]) 142 | for column in self.features 143 | ]) 144 | 145 | return df 146 | -------------------------------------------------------------------------------- /scripts/python/monitor_kaggle.py: -------------------------------------------------------------------------------- 1 | """Monitor Kaggle Submissions and Send Slack Notifications. 2 | 3 | This module monitors Kaggle competition submissions and sends notifications 4 | to a configured Slack channel when new submissions are detected or when 5 | the status of existing submissions changes (e.g., from "pending" to "complete"). 6 | 7 | Features: 8 | - Fetches Kaggle submissions using the Kaggle API. 9 | - Converts submission timestamps from UTC to the local time zone. 10 | - Tracks submission statuses and execution times. 11 | - Sends real-time notifications to Slack for new submissions and status updates. 12 | - Automatically handles Slack errors and retries. 13 | 14 | Environment Variables: 15 | SLACK_WEBHOOK_URL (str): Slack webhook URL for sending notifications. 16 | KAGGLE_COMPETITION_ID (str): ID of the Kaggle competition to monitor. 17 | 18 | Requirements: 19 | - Python 3.6+ 20 | - `requests` library for sending Slack notifications. 21 | - `dotenv` library for loading environment variables. 22 | - `pytz` library for time zone conversion. 23 | - Kaggle API configured and authenticated via CLI. 24 | 25 | Usage: 26 | 1. Set up the `.env` file with the required environment variables. 27 | 2. Ensure the Kaggle API is installed and authenticated. 28 | 3. Run the script to start monitoring submissions. 29 | 30 | Example `.env` file: 31 | SLACK_WEBHOOK_URL=https://hooks.slack.com/services/your/webhook/url 32 | KAGGLE_COMPETITION_ID=your-competition-id 33 | """ 34 | 35 | 36 | import os 37 | import time 38 | import requests 39 | import csv 40 | import subprocess 41 | from datetime import datetime 42 | from dotenv import load_dotenv 43 | import pytz 44 | 45 | load_dotenv() 46 | 47 | SLACK_WEBHOOK_URL = os.environ.get("SLACK_WEBHOOK_URL") 48 | submission_statuses = {} 49 | first_iteration = True 50 | local_timezone = pytz.timezone('Europe/London') 51 | 52 | 53 | def send_slack_notification(message): 54 | """ 55 | Sends a notification message to Slack using the configured webhook URL. 56 | 57 | Args: 58 | message (str): The message to be sent to Slack. 
59 | 60 | Returns: 61 | None 62 | """ 63 | response = requests.post(SLACK_WEBHOOK_URL, json={"text": message}) 64 | if response.status_code != 200: 65 | print(f"Slack API returned an error. HTTP Status: {response.status_code}") 66 | else: 67 | print("Slack notification sent successfully.") 68 | 69 | 70 | def convert_to_local_time(date_time_str): 71 | """ 72 | Converts a UTC date-time string to the local time zone. 73 | 74 | Args: 75 | date_time_str (str): UTC date-time string in the format '%Y-%m-%d %H:%M:%S'. 76 | 77 | Returns: 78 | str: Localized date-time string in the format '%Y-%m-%d %H:%M:%S'. 79 | """ 80 | date_time_obj = datetime.strptime(date_time_str, "%Y-%m-%d %H:%M:%S") 81 | utc_time = pytz.utc.localize(date_time_obj) 82 | local_time = utc_time.astimezone(local_timezone) 83 | return local_time.strftime("%Y-%m-%d %H:%M:%S") 84 | 85 | 86 | def local_time_to_timestamp(local_time_str): 87 | """ 88 | Converts a local time string to a Unix timestamp. 89 | 90 | Args: 91 | local_time_str (str): Local time string in the format '%Y-%m-%d %H:%M:%S'. 92 | 93 | Returns: 94 | int: Unix timestamp. 95 | """ 96 | local_time_obj = datetime.strptime(local_time_str, "%Y-%m-%d %H:%M:%S") 97 | return int(local_time_obj.timestamp()) 98 | 99 | 100 | def check_submissions(): 101 | """ 102 | Checks the status of Kaggle submissions and sends Slack notifications for new or updated submissions. 103 | 104 | Returns: 105 | None 106 | """ 107 | global first_iteration 108 | current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") 109 | print("-" * 70) 110 | print(f"{current_time}: Fetching Kaggle submissions...") 111 | 112 | kaggle_competition_id = os.getenv("KAGGLE_COMPETITION_ID") 113 | result = subprocess.run([ 114 | 'kaggle', 'competitions', 'submissions', 115 | '-c', kaggle_competition_id, '--csv' 116 | ], stdout=subprocess.PIPE, universal_newlines=True) 117 | submissions = result.stdout 118 | 119 | reader = csv.reader(submissions.splitlines()) 120 | next(reader) 121 | 122 | for row in reader: 123 | fileName, date, description, status, publicScore, privateScore = row 124 | if date == "date": 125 | continue 126 | 127 | date_time = convert_to_local_time(date) 128 | timestamp = local_time_to_timestamp(date_time) 129 | 130 | print(f"{current_time}: {fileName}, {date_time}, {description}, {status}, {publicScore}, {timestamp}") 131 | 132 | if timestamp not in submission_statuses: 133 | if first_iteration: 134 | submission_statuses[timestamp] = status 135 | else: 136 | submission_statuses[timestamp] = status 137 | notification_message = (f"*{current_time}* :rocket: *New Kaggle Submission Detected*\n" 138 | f"> *Submission:* `{description}`\n" 139 | f"> *Submitted at:* `{date_time}`\n" 140 | f"> *Status:* `{status}`\n" 141 | f"> *Score:* `{publicScore}`") 142 | print(f"{current_time}: New submission detected: {description} at {date_time}. 
Status: {status}.") 143 | send_slack_notification(notification_message) 144 | else: 145 | if status == "complete" and submission_statuses.get(timestamp) == "pending": 146 | submission_statuses[timestamp] = "complete" 147 | current_timestamp = int(time.time()) 148 | execution_time = current_timestamp - timestamp 149 | execution_hours = int(execution_time // 3600) 150 | execution_minutes = int((execution_time % 3600) // 60) 151 | notification_message = (f"*{current_time}* :bell: *Kaggle Submission Update*\n" 152 | f"> *Submission:* `{description}`\n" 153 | f"> *Submitted at:* `{date_time}`\n" 154 | f"> *Completed at:* `{current_time}`\n" 155 | f"> *Status:* `{status}`\n" 156 | f"> *Score:* `{publicScore}`\n" 157 | f"> *Time taken:* `{execution_hours}h {execution_minutes}m`") 158 | print( 159 | f"{current_time}: Submission at {date_time} " 160 | f"(timestamp: {timestamp}) with description {description} " 161 | f"has changed status to {status}. Time taken: {execution_hours}h {execution_minutes}m." 162 | ) 163 | send_slack_notification(notification_message) 164 | 165 | submission_statuses[timestamp] = status 166 | 167 | if first_iteration: 168 | first_iteration = False 169 | 170 | 171 | def main(): 172 | """ 173 | Main function to monitor Kaggle submissions and send notifications. 174 | 175 | Returns: 176 | None 177 | """ 178 | load_dotenv() 179 | send_slack_notification(":mag: *Monitoring Kaggle submissions...*") 180 | 181 | while True: 182 | check_submissions() 183 | time.sleep(60) 184 | 185 | 186 | if __name__ == "__main__": 187 | main() 188 | -------------------------------------------------------------------------------- /solution.md: -------------------------------------------------------------------------------- 1 | # Solution description 2 | 3 | ## 1. Cross-validation 4 | 5 | I used a time-series CV with two folds. The validation size was set to 200 dates, as in the public dataset. It correlated well with the public LB scores. Additionally, the model from the first fold was tested on the last 200 dates with a 200-day gap to simulate the private dataset scenario. 6 | 7 | ## 2. Feature engineering and data preparation 8 | 9 | ## 2.1 Sample 10 | 11 | I used data starting from `date_id = 700`, as this is when the number of `time_id`s stabilizes at 968. I experimented with using the entire dataset, but it did not result in any score improvement. 12 | 13 | ## 2.2 Data preparation 14 | 15 | Simple standardization and NaN imputation with zero were applied. Other methods didn't provide any improvement. 16 | 17 | ## 2.3 Feature engeneering 18 | 19 | I used all original features except for three categorical ones (features 09–11). I also selected 16 features that showed a high correlation with the target and created two groups of additional features: 20 | 21 | - Market averages: Averages per `date_id` and `time_id`. 22 | - Rolling statistics: Rolling averages and standard deviations over the last 1000 `time_id`s for each symbol. 23 | 24 | Besides that, I added `time_id` as a feature. 25 | 26 | Adding these features resulted in an improvement of about +0.002 on CV. 27 | 28 | ## 3. Model architecture 29 | 30 | ## 3.1 Base model 31 | 32 | Time-series GRU with sequence equal to one day. I ended up with two slightly different architectures: 33 | 34 | - 3-layer GRU 35 | - 1-layer GRU followed by 2 linear layers with ReLU activation and dropout. 36 | 37 | The second model worked better than the first model on CV (+0.001), but the first model still contributed to the ensemble, so I kept it. 
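Purely for illustration, the second architecture corresponds roughly to the sketch below. The real implementation lives in `janestreet/models/nn.py` and also includes the auxiliary-responder heads described in section 3.2; the layer sizes and dropout rates here are the "Model 2" values from `run_cv.py`, everything else is an assumption.

```python
import torch
import torch.nn as nn


class GRUModelSketch(nn.Module):
    """Minimal sketch: 1-layer GRU followed by 2 linear layers with ReLU and dropout."""

    def __init__(self, n_features: int, hidden_size: int = 500):
        super().__init__()
        self.gru = nn.GRU(n_features, hidden_size, num_layers=1, batch_first=True)
        self.head = nn.Sequential(
            nn.Linear(hidden_size, 500), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(500, 300), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(300, 1),
        )

    def forward(self, x: torch.Tensor, hidden: torch.Tensor | None = None):
        # x: one trading day as a sequence, shape (1, seq_len, n_features).
        out, hidden = self.gru(x, hidden)          # per-step GRU outputs
        return self.head(out).squeeze(-1), hidden  # per-step prediction of the target
```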
38 | 39 | MLP, time-series transformers, cross-symbol attention and embeddings didn't work for me. 40 | 41 | ### 3.2 Responders 42 | 43 | I used 4 responders as auxiliary targets: `responder_7` and `responder_8`, and two calculated ones: 44 | 45 | ```python 46 | df = df.with_columns( 47 | ( 48 | pl.col("responder_8") 49 | + pl.col("responder_8").shift(-4).over("symbol_id") 50 | ).fill_null(0.0).alias("responder_9"), 51 | ( 52 | pl.col("responder_6") 53 | + pl.col("responder_6").shift(-20).over("symbol_id") 54 | + pl.col("responder_6").shift(-40).over("symbol_id") 55 | ).fill_null(0.0).alias("responder_10"), 56 | ) 57 | ``` 58 | 59 | These are approximate rolling averages of the base target over 8 and 60 days, respectively. As described in detail in [this discussion](https://www.kaggle.com/competitions/jane-street-real-time-market-data-forecasting/discussion/555562) by @johnpayne0, `responder_6` is a 20-day rolling average of some variable, while `responder_7` and `responder_8` are 120-day and 4-day rolling averages of the same variable, with some added noise. Given an N-day rolling average, we can easily calculate N*K-day rolling averages. 60 | 61 | A separate base model was used for each auxiliary target. The predictions from these models were then passed through a linear layer to produce the final target output, `responder_6`. 62 | 63 | The sum of losses (weighted zero-mean R²) for each responder was used to train the model. 64 | 65 | Adding auxiliary targets improved both CV and LB scores by about +0.001. 66 | 67 | Models were trained using a batch size of one day, with a learning rate of 0.0005. 68 | For submission, I trained models on data up to the last date_id, using the number of epochs equal to the average optimal number of epochs on CV. 69 | 70 | ### 3.3 Ensemble 71 | 72 | I ran both models on 3 seeds and took a simple unweighted average of predictions from those 6 models. This resulted in an LB score of 0.0112 (vs best single model LB 0.0105). 73 | 74 | ## 4. Online Learning 75 | 76 | During inference, when new data with targets becomes available, I perform one forward pass to update the model weights with a learning rate of 0.0003. This approach significantly improved the model’s performance on CV (+0.008). Interestingly, for an MLP model, the score without online learning was higher than for the GRU, but lower with online learning. 77 | 78 | Updates are performed only with the `responder_6` loss, without auxiliary targets. 79 | 80 | Updates are applied for the entire dataset provided during submission, including rows with is_scored = False. 81 | 82 | I also considered performing a full online retraining on the data up to the start of the private dataset. This would make sense because there is a significant gap between the training data and the private dataset. However, retraining the model would require distributing the training process across multiple inference steps, as the one-minute time limit between dates would not be sufficient. I believe this would have been feasible but I decided not to spend time on it, although my tests suggested that it could provide a +0.001 improvement in the score. Still, I find it amazing that, instead of a full model retraining, performing one-day updates for almost a year is enough, and the model continues to perform well. 83 | 84 | ## 5. 
Technical details 85 | 86 | ### 5.1 Inference Speed 87 | 88 | Inference speed was critically important, so I spent a significant amount of time optimizing my code, particularly data processing and calculation of rolling features. 89 | 90 | For my final submission, it takes 0.06 seconds to run one inference step (`time_id`), 0.02 of which are spent on data processing. Updating model weights once per `date_id` takes 3.6 seconds. 91 | 92 | I used PyTorch, but since TensorFlow is said to be faster, I tried switching to it. However, after a few days of experimenting, I couldn't achieve better performance, so I decided to stick with PyTorch. 93 | 94 | ### 5.2 Technical stack 95 | 96 | Due to RAM requirements, I switched from Google Colab to vast.ai and was extremely happy with the decision. I wrote code locally, enjoying all the perks of VSCode, and then ran a script to push the code to github, pull it on the server and execute scripts remotely. 97 | 98 | I also used WandB to monitor experiments, which helped me keep track of scores and easily revert to an older version of the code if something went wrong. 99 | 100 | To debug my submission notebook and estimate submission time I used [synthetic dataset](https://www.kaggle.com/code/shiyili/js24-rmf-submission-api-debug-with-synthetic-test) by @shiyili. 101 | 102 | ### 6. Scores 103 | 104 | | | CV fold 0 | CV fold 1v | Fold 1 with 200 days gap | CV avg | 105 | | -------------------------------------------------------- | --------- | --------- | ------------------------ | ------ | 106 | | GRU 1 without both auxiliary targets and online learning | 0.0161 | 0.0062 | 0.0011 | 0.0112 | 107 | | GRU 1 without auxiliary targets | 0.0235 | 0.0148 | 0.0136 | 0.0190 | 108 | | GRU 1 | 0.0249 | 0.0153 | 0.0147 | 0.0201 | 109 | | GRU 2 | 0.0262 | 0.0166 | 0.0161 | 0.0214 | 110 | | GRU 1 + GRU 2 | 0.0268 | 0.0169 | 0.0163 | 0.0218 | 111 | | GRU 1 3 seeds | 0.0258 | 0.0164 | 0.0152 | 0.0211 | 112 | | GRU 2 3 seeds | 0.0267 | 0.0175 | 0.0163 | 0.0221 | 113 | | GRU 1 + GRU 2 3 seeds | 0.0270 | 0.0175 | 0.0162 | 0.0222 | 114 | 115 | Fold 0: `date_id`s from 1298 to 1498. 116 | Fold 1: `date_id`s from 1499 to 1698. 117 | -------------------------------------------------------------------------------- /janestreet/data_processor.py: -------------------------------------------------------------------------------- 1 | """Custom data processor for feature engineering and transformation. 2 | """ 3 | 4 | import os 5 | import joblib 6 | import polars as pl 7 | 8 | from janestreet.config import PATH_DATA, PATH_MODELS 9 | from janestreet import utils 10 | from janestreet.transformers import PolarsTransformer 11 | 12 | class DataProcessor: 13 | """Custom data processor for feature engineering and transformation. 14 | 15 | This class handles loading, processing, and transforming data for training and testing. 16 | It includes methods for adding features such as rolling averages, standard deviations, 17 | and market averages. 18 | 19 | Attributes: 20 | PATH (str): Path to save and load data processors. 21 | COLS_FEATURES_INIT (list[str]): Initial feature columns. 22 | COLS_FEATURES_CORR (list[str]): Correlated feature columns for additional processing. 23 | COLS_FEATURES_CAT (list[str]): Categorical feature columns. 24 | T (int): Window size for rolling computations. 25 | name (str): Name of the data processor. 26 | skip_days (int or None): Number of days to skip when loading data. 27 | transformer (PolarsTransformer or None): Transformer for data preprocessing. 
28 | features (list[str]): List of feature columns after processing. 29 | """ 30 | PATH = os.path.join(PATH_MODELS, "data_processors") 31 | 32 | COLS_FEATURES_INIT = [f"feature_{i:02d}" for i in range(79)] 33 | 34 | COLS_FEATURES_CORR = [ 35 | 'feature_06', 36 | 'feature_04', 37 | 'feature_07', 38 | 'feature_36', 39 | 'feature_60', 40 | 'feature_45', 41 | 'feature_56', 42 | 'feature_05', 43 | 'feature_51', 44 | 'feature_19', 45 | 'feature_66', 46 | 'feature_59', 47 | 'feature_54', 48 | 'feature_70', 49 | 'feature_71', 50 | 'feature_72', 51 | ] 52 | COLS_FEATURES_CAT = ["feature_09", "feature_10", "feature_11"] 53 | 54 | T = 1000 55 | 56 | def __init__( 57 | self, 58 | name: str, 59 | skip_days: int = None, 60 | transformer: PolarsTransformer | None = None 61 | ): 62 | """Initializes the DataProcessor. 63 | 64 | Args: 65 | name (str): Name of the data processor. 66 | skip_days (int, optional): Number of days to skip when loading data. Defaults to None. 67 | transformer (PolarsTransformer, optional): Transformer for data preprocessing. 68 | Defaults to None. 69 | """ 70 | self.name = name 71 | self.skip_days = skip_days 72 | self.transformer = transformer 73 | 74 | self.features = list(self.COLS_FEATURES_INIT) 75 | self.features += [f"{i}_diff_rolling_avg_{self.T}" for i in self.COLS_FEATURES_CORR] 76 | self.features += [f"{i}_rolling_std_{self.T}" for i in self.COLS_FEATURES_CORR] 77 | self.features += [f"{i}_avg_per_date_time" for i in self.COLS_FEATURES_CORR] 78 | self.features += ["feature_time_id"] 79 | self.features = [i for i in self.features if i not in self.COLS_FEATURES_CAT] 80 | 81 | utils.create_folder(self.PATH) 82 | 83 | def get_train_data(self) -> pl.DataFrame: 84 | """Loads, processes, and returns training data. 85 | 86 | Returns: 87 | pl.DataFrame: Processed training data. 88 | """ 89 | df = self._load_data().collect() 90 | 91 | # Additional responders 92 | # (8- and 60-days moving average) 93 | df = df.with_columns( 94 | ( 95 | pl.col("responder_8") 96 | + pl.col("responder_8").shift(-4).over("symbol_id") 97 | ).fill_null(0.0).alias("responder_9"), 98 | ( 99 | pl.col("responder_6") 100 | + pl.col("responder_6").shift(-20).over("symbol_id") 101 | + pl.col("responder_6").shift(-40).over("symbol_id") 102 | ).fill_null(0.0).alias("responder_10"), 103 | ) 104 | 105 | df = self._add_features(df) 106 | 107 | if self.transformer is not None: 108 | self.transformer.set_features(self.features) 109 | df = self.transformer.fit_transform(df) 110 | 111 | self._save() 112 | return df 113 | 114 | def process_test_data( 115 | self, 116 | df: pl.DataFrame, 117 | fast: bool = False, 118 | date_id: int = 0, 119 | time_id: int = 0, 120 | symbols: list = None 121 | ) -> pl.DataFrame: 122 | """Processes test data. 123 | 124 | Args: 125 | df (pl.DataFrame): DataFrame containing test data. 126 | fast (bool, optional): Whether to use fast processing mode. Defaults to False. 127 | date_id (int, optional): Current date id. Defaults to 0. 128 | time_id (int, optional): Current time id. Defaults to 0. 129 | symbols (list, optional): List of symbols to process. Defaults to None. 130 | 131 | Returns: 132 | pl.DataFrame: Processed test data. 133 | """ 134 | df = self._add_features(df, fast=fast, date_id=date_id, time_id=time_id, symbols=symbols) 135 | if self.transformer is not None: 136 | df = self.transformer.transform(df, refit=True) 137 | return df 138 | 139 | def _save(self): 140 | """Saves the data processor to disk. 
141 | """ 142 | joblib.dump(self, f"{self.PATH}/{self.name}.joblib") 143 | 144 | def load(self): 145 | """Loads the data processor from disk. 146 | 147 | Returns: 148 | DataProcessor: Loaded data processor. 149 | """ 150 | return joblib.load(f"{self.PATH}/{self.name}.joblib") 151 | 152 | def _load_data(self) -> pl.DataFrame: 153 | """Loads the training data from disk. 154 | 155 | Returns: 156 | pl.DataFrame: Loaded training data. 157 | """ 158 | df = pl.scan_parquet(f'{PATH_DATA}/train.parquet') 159 | df = df.drop("partition_id") 160 | if self.skip_days is not None: 161 | df = df.filter(pl.col("date_id")>=self.skip_days) 162 | return df 163 | 164 | def _add_features( 165 | self, 166 | df: pl.DataFrame, 167 | fast: bool = False, 168 | date_id: int | None = None, 169 | time_id: int | None = None, 170 | symbols: list = None 171 | ) -> pl.DataFrame: 172 | """Adds features to the data. 173 | 174 | Args: 175 | df (pl.DataFrame): DataFrame to process. 176 | fast (bool, optional): Whether to use fast processing mode. Defaults to False. 177 | date_id (int, optional): Current date ID. Defaults to None. 178 | time_id (int, optional): Current time ID. Defaults to None. 179 | symbols (list, optional): List of symbols to process. Defaults to None. 180 | 181 | Returns: 182 | pl.DataFrame: DataFrame with added features. 183 | """ 184 | df = self._get_window_average_std( 185 | df, 186 | self.COLS_FEATURES_CORR, 187 | n=self.T, 188 | fast=fast, 189 | date_id=date_id, 190 | time_id=time_id, 191 | symbols=symbols 192 | ) 193 | df = self._get_market_average(df, self.COLS_FEATURES_CORR, fast=fast) 194 | 195 | df = df.with_columns( 196 | pl.col("time_id").alias("feature_time_id"), 197 | 198 | ) 199 | return df 200 | 201 | def _get_window_average_std( 202 | self, 203 | df: pl.DataFrame, 204 | cols: list, 205 | n: int = 1000, 206 | fast: bool = False, 207 | date_id: int | None = None, 208 | time_id: int | None = None, 209 | symbols: list = None 210 | ) -> pl.DataFrame: 211 | """Computes rolling averages and standard deviations. 212 | 213 | Args: 214 | df (pl.DataFrame): DataFrame to process. 215 | cols (list): List of columns for which to compute rolling statistics. 216 | n (int, optional): Window size. Defaults to 1000. 217 | fast (bool, optional): Whether to use fast processing mode. Defaults to False. 218 | If True, date_id, time_id and symbols args should be set. 219 | date_id (int, optional): Current date ID. Defaults to None. 220 | time_id (int, optional): Current time ID. Defaults to None. 221 | symbols (list, optional): List of symbols to process. Defaults to None. 222 | 223 | Returns: 224 | pl.DataFrame: DataFrame with rolling averages and standard deviations. 
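Note: in fast mode, the mean and standard deviation are computed over all rows provided for each symbol (no explicit rolling window), and only the latest feature row per symbol is kept before the difference feature is built.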
225 | """ 226 | if not fast: 227 | df = df.with_columns([ 228 | pl.col(col).rolling_mean(window_size=n) 229 | .over(["symbol_id"]).alias(f"{col}_rolling_avg_{n}") 230 | for col in cols 231 | ] + [ 232 | pl.col(col).rolling_std(window_size=n) 233 | .over(["symbol_id"]).alias(f"{col}_rolling_std_{n}") 234 | for col in cols 235 | ]) 236 | else: 237 | df = df.group_by("symbol_id").agg([ 238 | pl.col(col).mean().alias(f"{col}_rolling_avg_{n}") 239 | for col in cols 240 | ] + [ 241 | pl.col(col).std().alias(f"{col}_rolling_std_{n}") 242 | for col in cols 243 | ] + [ 244 | pl.col(col).last().alias(col) 245 | for col in self.COLS_FEATURES_INIT + ["row_id", "weight", "is_scored"] 246 | ]).filter(pl.col("symbol_id").is_in(symbols)) 247 | df = df.with_columns( 248 | pl.lit(date_id).cast(pl.Int16).alias("date_id"), 249 | pl.lit(time_id).cast(pl.Int16).alias("time_id") 250 | ) 251 | 252 | df = df.with_columns([ 253 | (pl.col(col) - pl.col(f"{col}_rolling_avg_{n}")).alias(f"{col}_diff_rolling_avg_{n}") 254 | for col in cols 255 | ]) 256 | df = df.drop([f"{col}_rolling_avg_{n}" for col in cols]) 257 | return df 258 | 259 | 260 | def _get_market_average( 261 | self, 262 | df: pl.DataFrame, 263 | cols: list, 264 | fast: bool = False 265 | ) -> pl.DataFrame: 266 | """Computes market averages (average per date_id and time_id). 267 | 268 | Args: 269 | df (pl.DataFrame): DataFrame to process. 270 | cols (list): List of columns for which to compute market averages. 271 | fast (bool, optional): Whether to use fast processing mode. Defaults to False. 272 | 273 | Returns: 274 | pl.DataFrame: DataFrame with market averages. 275 | """ 276 | if not fast: 277 | df = df.with_columns([ 278 | pl.col(col) 279 | .mean().over(["date_id", "time_id"]) 280 | .alias(f"{col}_avg_per_date_time") 281 | for col in cols 282 | ]) 283 | else: 284 | df = df.with_columns([ 285 | pl.col(col) 286 | .mean() 287 | .alias(f"{col}_avg_per_date_time") 288 | for col in cols 289 | ]) 290 | return df 291 | -------------------------------------------------------------------------------- /janestreet/pipeline.py: -------------------------------------------------------------------------------- 1 | """Custom pipeline classes. 2 | 3 | These classes are designed to facilitate model training, cross-validation, and testing 4 | on time series data, with additional functionalities for model management. 5 | 6 | Classes: 7 | FullPipeline: A custom pipeline for model training, saving, loading, and updating. 8 | PipelineEnsemble: An ensemble pipeline that aggregates predictions from multiple models. 9 | PipelineCV: A cross-validation pipeline designed for time series data. 10 | """ 11 | 12 | import copy 13 | import gc 14 | import os 15 | 16 | 17 | import joblib 18 | import numpy as np 19 | import polars as pl 20 | 21 | from tqdm.auto import tqdm 22 | 23 | from sklearn.base import clone 24 | from sklearn.model_selection import TimeSeriesSplit 25 | from sklearn.base import BaseEstimator 26 | 27 | import torch 28 | 29 | from .config import PATH_MODELS, COL_TARGET, COL_ID, COL_DATE, COL_TIME, COL_WEIGHT, COLS_RESPONDERS 30 | from . import utils 31 | from .metrics import r2_weighted 32 | from .tracker import WandbTracker 33 | 34 | 35 | TEST_SIZE = 200 36 | GAP = 0 37 | 38 | 39 | class FullPipeline: 40 | """Custom pipeline for model management and time series training. 41 | 42 | This class provides methods for fitting, predicting, updating, saving, and loading models. 43 | 44 | Attributes: 45 | model (BaseEstimator): The model to be used. 
46 | preprocessor: Optional data preprocessing pipeline. 47 | name (str): Name of the model for saving/loading purposes. 48 | load_model (bool): Flag indicating whether to load the model from disk. 49 | features (list[str] or None): List of feature names to use. 50 | save_to_disc (bool): Flag indicating whether to save the model to disk. 51 | refit (bool): Flag indicating whether to refit the model during updates. 52 | change_lr (bool): Flag indicating whether to change learning rate (if load is True). 53 | col_target (str): Name of the target column. 54 | """ 55 | def __init__( 56 | self, 57 | model: BaseEstimator, 58 | preprocessor = None, 59 | run_name: str = "", 60 | name: str = "", 61 | load_model: bool = False, 62 | features: list | None = None, 63 | save_to_disc: bool = True, 64 | refit = True, 65 | change_lr = False, 66 | col_target = COL_TARGET, 67 | ) -> None: 68 | """Initializes the FullPipeline. 69 | 70 | Args: 71 | model (BaseEstimator): The model to be used. 72 | preprocessor: Optional preprocessing pipeline. 73 | run_name (str): Name of the current run. 74 | name (str): Name of the model for saving/loading purposes. 75 | load_model (bool): Whether to load the model from disk. 76 | features (list[str] or None): List of feature names to use. 77 | save_to_disc (bool): Whether to save the model to disk. 78 | refit (bool): Whether to refit the model during updates. 79 | change_lr (bool): Whether to change learning rate (if load is True). 80 | col_target (str): Name of the target column. 81 | """ 82 | self.model = model 83 | self.preprocessor = preprocessor 84 | self.name = name 85 | self.load_model = load_model 86 | self.features = features 87 | self.save_to_disc = save_to_disc 88 | self.refit = refit 89 | self.change_lr = change_lr 90 | self.col_target = col_target 91 | 92 | self.responders = [i for i in COLS_RESPONDERS if i != self.col_target] 93 | 94 | self.set_run_name(run_name) 95 | self.path = os.path.join(PATH_MODELS, f"{self.run_name}") 96 | 97 | def set_run_name(self, run_name: str) -> None: 98 | """Sets the run name for the model. 99 | 100 | Args: 101 | run_name (str): The name of the run. 102 | 103 | """ 104 | self.run_name = run_name 105 | self.path = os.path.join(PATH_MODELS, f"{self.run_name}") 106 | if self.save_to_disc: 107 | utils.create_folder(self.path) 108 | 109 | def fit( 110 | self, 111 | df: pl.DataFrame | None = None, 112 | df_valid: pl.DataFrame | None = None, 113 | verbose: bool = False, 114 | ) -> None: 115 | """Fits the model pipeline. 116 | 117 | Args: 118 | df (pl.DataFrame | None): DataFrame containing training data. 119 | df_valid (pl.DataFrame | None): DataFrame containing validation data. 120 | verbose (bool): Whether to enable verbose output during fitting. 
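Note: when `load_model` is True, fitting is skipped and the model (and preprocessor, if present) is loaded from disk instead.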
121 | 122 | """ 123 | if not self.load_model: 124 | self.model.features = self.features 125 | 126 | weights_train = df.select(COL_WEIGHT).to_series().to_numpy() 127 | dates_train = df.select(COL_DATE).to_series().to_numpy() 128 | times_train = df.select(COL_TIME).to_series().to_numpy() 129 | stocks_train = df.select(COL_ID).to_series().to_numpy() 130 | 131 | weights_valid = df_valid.select(COL_WEIGHT).to_series().to_numpy() 132 | dates_valid = df_valid.select(COL_DATE).to_series().to_numpy() 133 | times_valid = df_valid.select(COL_TIME).to_series().to_numpy() 134 | stocks_valid = df_valid.select(COL_ID).to_series().to_numpy() 135 | 136 | if self.preprocessor is not None: 137 | df = self.preprocessor.fit_transform(df) 138 | df_valid = self.preprocessor.transform(df_valid) 139 | 140 | X_train = df.select(self.features).to_numpy() 141 | resp_train = df.select(self.responders).to_numpy() 142 | y_train = df.select(self.col_target).to_series().to_numpy() 143 | 144 | X_valid = df_valid.select(self.features).to_numpy() 145 | resp_valid = df_valid.select(self.responders).to_numpy() 146 | y_valid = df_valid.select(self.col_target).to_series().to_numpy() 147 | 148 | train_set = ( 149 | X_train, 150 | resp_train, 151 | y_train, 152 | weights_train, 153 | stocks_train, 154 | dates_train, 155 | times_train 156 | ) 157 | val_set = ( 158 | X_valid, 159 | resp_valid, 160 | y_valid, 161 | weights_valid, 162 | stocks_valid, 163 | dates_valid, 164 | times_valid 165 | ) 166 | 167 | del df, df_valid 168 | gc.collect() 169 | 170 | self.model.fit(train_set, val_set, verbose) 171 | if self.save_to_disc: 172 | self.save() 173 | else: 174 | self.load() 175 | 176 | def predict( 177 | self, 178 | df: pl.DataFrame, 179 | hidden: torch.Tensor | list | None = None, 180 | n_times: int | None = None 181 | ) -> tuple[np.ndarray, torch.Tensor | list]: 182 | """Predicts target using the fitted model. 183 | 184 | Args: 185 | df (pl.DataFrame): DataFrame containing data for prediction. 186 | hidden (torch.Tensor | list | None): Hidden states for recurrent models. 187 | n_times (int | None): Number of time steps to predict. 188 | 189 | Returns: 190 | tuple[np.ndarray, torch.Tensor | list]: Predicted probabilities and hidden states. 191 | 192 | """ 193 | if n_times is None: 194 | n_times = len(df.select(COL_TIME).unique()) 195 | if self.preprocessor is not None: 196 | df = self.preprocessor.transform(df) 197 | X = df.select(self.features).to_numpy() 198 | preds, hidden = self.model.predict(X, hidden=hidden, n_times=n_times) 199 | preds = np.clip(preds, -5, 5) 200 | return preds, hidden 201 | 202 | def update(self, df: pl.DataFrame) -> None: 203 | """Updates model weights using new data. 204 | 205 | Args: 206 | df (pl.DataFrame): DataFrame containing data for updating the model. 
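Note: a single gradient step is taken on the provided data; the underlying `NN.update` skips the weight update entirely when `lr_refit` is 0.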
207 | 208 | """ 209 | weights = df.select(COL_WEIGHT).to_series().to_numpy() 210 | n_times = len(df.select(COL_TIME).unique()) 211 | if self.preprocessor is not None: 212 | df = self.preprocessor.transform(df, refit=True) 213 | 214 | X = df.select(self.features).to_numpy() 215 | y = df.select(self.col_target).to_series().to_numpy() 216 | self.model.update(X, y, weights, n_times) 217 | 218 | def load(self) -> None: 219 | """Loads the model from disk.""" 220 | if self.change_lr: 221 | lr_refit = self.model.lr_refit 222 | self.model = joblib.load(f"{self.path}/model_{self.name}.joblib") 223 | self.features = self.model.features 224 | if self.change_lr: 225 | self.model.lr_refit = lr_refit 226 | try: 227 | self.preprocessor = joblib.load(f"{self.path}/preprocessor_{self.name}.joblib") 228 | except FileNotFoundError: 229 | self.preprocessor = None 230 | print("WARNING: Preprocessor not found.") 231 | 232 | def save(self) -> None: 233 | """Saves the model to disk.""" 234 | 235 | joblib.dump(self.model, f"{self.path}/model_{self.name}.joblib") 236 | if self.preprocessor is not None: 237 | joblib.dump(self.preprocessor, f"{self.path}/preprocessor_{self.name}.joblib") 238 | 239 | def get_params(self, deep: bool = True) -> dict: 240 | """Gets parameters for the estimator. 241 | 242 | Args: 243 | deep (bool): Whether to return the parameters of sub-objects. 244 | 245 | Returns: 246 | dict: Dictionary of parameters. 247 | """ 248 | return { 249 | "model": self.model, 250 | "preprocessor": self.preprocessor, 251 | "name": self.name, 252 | "load_model": self.load_model, 253 | "features": self.features, 254 | "save_to_disc": self.save_to_disc, 255 | "refit": self.refit, 256 | "change_lr": self.change_lr, 257 | "col_target": self.col_target, 258 | } 259 | 260 | def set_params(self, **parameters): 261 | """Sets the parameters of the estimator. 262 | 263 | Args: 264 | parameters: A dictionary of parameter names and values. 265 | 266 | Returns: 267 | self: The updated estimator. 268 | """ 269 | for parameter, value in parameters.items(): 270 | setattr(self, parameter, value) 271 | return self 272 | 273 | 274 | class PipelineEnsemble: 275 | """Ensemble pipeline for aggregating predictions from multiple models. 276 | 277 | This class manages multiple models, allowing for fitting, prediction, updating, 278 | and managing ensemble weights. It is designed for time series data, where 279 | different models can be combined to improve overall prediction performance. 280 | 281 | Attributes: 282 | models (list): List of models to be used in the ensemble. 283 | weights (np.ndarray): Array of weights for averaging predictions. 284 | refit_models (list[bool]): Flags indicating whether each model 285 | should be refit during updates. 286 | col_target (str): Name of the target column. 287 | """ 288 | def __init__( 289 | self, 290 | models: list, 291 | weights: np.array = None, 292 | refit_models: list[bool] = None, 293 | col_target: str = COL_TARGET 294 | ) -> None: 295 | """Initializes the PipelineEnsemble. 296 | 297 | Args: 298 | models (list): List of models to be used in the ensemble. 299 | weights (numpy array or None): Weights for averaging model predictions. 300 | refit_models (list[bool] or None): Flags for refitting models during updates. 301 | col_target (str): Name of the target column. 
302 | """ 303 | self.models = models 304 | self.weights = weights if weights is not None else np.ones(len(self.models)) 305 | self.refit_models = refit_models if refit_models is not None else [True]*len(models) 306 | self.col_target = col_target 307 | self.refit = True 308 | 309 | def fit( 310 | self, 311 | df: pl.DataFrame | None = None, 312 | df_valid: pl.DataFrame | None = None, 313 | verbose: bool = False, 314 | ) -> None: 315 | """Fits all models in the ensemble. 316 | 317 | Args: 318 | df (pl.DataFrame | None): DataFrame containing the training data. 319 | df_valid (pl.DataFrame | None): DataFrame containing the validation data. 320 | verbose (bool): Enables verbose output during fitting. 321 | 322 | """ 323 | self.weights = np.array(self.weights) / sum(self.weights) 324 | for model in self.models: 325 | model.fit(df, df_valid, verbose) 326 | 327 | def set_run_name(self, run_name: str) -> None: 328 | """Sets the run name for all models in the ensemble. 329 | 330 | Args: 331 | run_name (str): The name of the run. 332 | 333 | """ 334 | for model in self.models: 335 | model.set_run_name(run_name) 336 | 337 | def predict(self, df: pl.DataFrame, hidden_ls=None) -> np.ndarray: 338 | """Predicts probabilities using all models in the ensemble. 339 | 340 | Args: 341 | df (pl.DataFrame): DataFrame containing the data for prediction. 342 | hidden_ls (list or None): List of hidden states for each model. 343 | 344 | Returns: 345 | tuple[np.ndarray, list]: Averaged predictions and updated hidden states. 346 | """ 347 | if hidden_ls is None: 348 | hidden_ls = [None] * len(self.models) 349 | 350 | preds = [] 351 | for i, model in enumerate(self.models): 352 | preds_i, hidden_ls[i] = model.predict(df, hidden=hidden_ls[i]) 353 | preds.append(preds_i) 354 | 355 | preds = np.average(preds, axis=0, weights=self.weights) 356 | return preds, hidden_ls 357 | 358 | def update(self, df: pl.DataFrame) -> None: 359 | """Updates models weights using new data. 360 | 361 | Args: 362 | df (pl.DataFrame): DataFrame containing data for updating the models. 363 | 364 | """ 365 | for i, model in enumerate(self.models): 366 | if self.refit_models[i]: 367 | model.update(df) 368 | 369 | 370 | def load(self) -> None: 371 | """Loads all models in the ensemble from disk. 372 | """ 373 | for model in self.models: 374 | model.model.load() 375 | 376 | def save(self) -> None: 377 | """Saves all models in the ensemble to disk. 378 | """ 379 | for model in self.models: 380 | model.model.save() 381 | 382 | def get_params(self, deep: bool = True) -> dict: 383 | """Gets parameters for the ensemble. 384 | 385 | Args: 386 | deep (bool): Whether to return parameters of sub-objects. 387 | 388 | Returns: 389 | dict: Dictionary of parameters. 390 | """ 391 | return { 392 | "models": self.models, 393 | "weights": self.weights, 394 | "refit_models": self.refit_models, 395 | "col_target": self.col_target, 396 | } 397 | 398 | def set_params(self, **parameters): 399 | """Sets the parameters of the ensemble. 400 | 401 | Args: 402 | parameters (dict): A dictionary of parameter names and values. 403 | 404 | Returns: 405 | self: The updated ensemble. 406 | """ 407 | for parameter, value in parameters.items(): 408 | setattr(self, parameter, value) 409 | return self 410 | 411 | class PipelineCV: 412 | """Cross-validation pipeline for time series models. 413 | 414 | This class manages cross-validation for time series models, allowing for fitting models on 415 | multiple folds, tracking results, and handling time-based data splits. 
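Validation predictions are generated one date at a time; when the model's `refit` flag is set, the fold model is updated on the previous date's data before predicting, mimicking online learning at inference time.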
416 | 417 | Attributes: 418 | model (FullPipeline): The model to be validated. 419 | tracker (WandbTracker): Tracker for logging metrics during cross-validation. 420 | n_splits (int): Number of cross-validation splits. 421 | train_size (int): Maximum size of the training set. 422 | models (list): List of models fitted on each fold. 423 | """ 424 | def __init__( 425 | self, 426 | model: FullPipeline, 427 | tracker: WandbTracker, 428 | n_splits: int, 429 | train_size: int = False, 430 | ) -> None: 431 | """Initializes the PipelineCV. 432 | 433 | Args: 434 | model (FullPipeline): The model to be validated. 435 | tracker (WandbTracker): Tracker for logging metrics. 436 | n_splits (int): Number of cross-validation splits. 437 | train_size (int, optional): Maximum size of the training set. Defaults to False. 438 | """ 439 | self.model = model 440 | self.tracker = tracker 441 | self.n_splits = n_splits 442 | self.train_size = train_size 443 | self.models = [] 444 | 445 | def fit( 446 | self, 447 | df: pl.DataFrame, 448 | verbose: bool = False, 449 | ) -> list: 450 | """Fits models on cross-validation folds. 451 | 452 | Args: 453 | df (pl.DataFrame): DataFrame containing the data. 454 | verbose (bool, optional): Whether to print verbose output. Defaults to False. 455 | 456 | Returns: 457 | list: Scores for each fold. 458 | """ 459 | dates_unique = df.select(pl.col(COL_DATE).unique().sort()).to_series().to_numpy() 460 | 461 | test_size = ( 462 | TEST_SIZE 463 | if len(dates_unique) > TEST_SIZE * (self.n_splits + 1) 464 | else len(dates_unique) // (self.n_splits + 1) 465 | ) # For testing purposes on small samples 466 | cv = TimeSeriesSplit( 467 | n_splits=self.n_splits, 468 | test_size=test_size, 469 | max_train_size=self.train_size 470 | ) 471 | cv_split = cv.split(dates_unique) 472 | 473 | scores = [] 474 | for fold, (train_idx, valid_idx) in enumerate(cv_split): 475 | if verbose: 476 | print("-"*20 + f"Fold {fold}" + "-"*20) 477 | print( 478 | f"Train dates from {dates_unique[train_idx].min()}" 479 | f" to {dates_unique[train_idx].max()}" 480 | ) 481 | print( 482 | f"Valid dates from {dates_unique[valid_idx].min()}" 483 | f" to {dates_unique[valid_idx].max()}" 484 | ) 485 | 486 | dates_train = dates_unique[train_idx] 487 | dates_valid = dates_unique[valid_idx] 488 | 489 | df_train = df.filter(pl.col(COL_DATE).is_in(dates_train)) 490 | df_valid = df.filter(pl.col(COL_DATE).is_in(dates_valid)) 491 | 492 | model_fold = clone(self.model) 493 | model_fold.set_run_name(f"fold{fold}") 494 | model_fold.fit(df_train, df_valid, verbose=verbose) 495 | 496 | self.models.append(model_fold) 497 | 498 | preds = [] 499 | cnt_dates = 0 500 | model_save = copy.deepcopy(model_fold) 501 | for date_id in tqdm(dates_valid): 502 | df_valid_date = df_valid.filter(pl.col(COL_DATE) == date_id) 503 | 504 | if model_fold.refit & (cnt_dates > 0): 505 | df_upd = df.filter(pl.col(COL_DATE)==date_id-1) 506 | if len(df_upd) > 0: 507 | model_save.update(df_upd) 508 | 509 | preds_i, _ = model_save.predict(df_valid_date) 510 | preds += list(preds_i) 511 | cnt_dates += 1 512 | preds = np.array(preds) 513 | 514 | df_valid = df_valid.fill_null(0.0) 515 | y_true = df_valid.select(pl.col(model_fold.col_target)).to_series().to_numpy() 516 | weights = df_valid.select(pl.col(COL_WEIGHT)).to_series().to_numpy() 517 | score = r2_weighted(y_true, preds, weights) 518 | scores.append(score) 519 | 520 | print(f"R2: {score:.5f}") 521 | if self.tracker: 522 | self.tracker.log_metrics({f"fold_{fold}": score}) 523 | 524 | if self.tracker: 525 | 
self.tracker.log_metrics({"cv": np.mean(scores)}) 526 | return scores 527 | 528 | def load(self) -> None: 529 | """Loads models for each fold from disk. 530 | """ 531 | self.models = [] 532 | for i in range(self.n_splits): 533 | model = clone(self.model) 534 | model.set_run_name(f"fold{i}") 535 | model.fit() 536 | self.models.append(model) 537 | -------------------------------------------------------------------------------- /janestreet/models/nn.py: -------------------------------------------------------------------------------- 1 | """Neural net. 2 | """ 3 | import copy 4 | import numpy as np 5 | from tqdm.auto import tqdm 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader, Dataset 10 | 11 | from janestreet.metrics import r2_weighted_torch, WeightedR2Loss 12 | 13 | 14 | def flatten_collate_fn(batch: list) -> tuple[torch.Tensor]: 15 | """ 16 | Collate function for DataLoader to flatten the batch. 17 | 18 | Args: 19 | batch (list): List of tuples containing tensors. 20 | 21 | tuple[torch.Tensor]: Flattened tensors (X, resp, y, weights). 22 | """ 23 | X, resp, y, weights = zip(*batch) 24 | X = torch.cat(X, dim=0) 25 | resp = torch.cat(resp, dim=0) 26 | y = torch.cat(y, dim=0) 27 | weights = torch.cat(weights, dim=0) 28 | 29 | return X, resp, y, weights 30 | 31 | 32 | class CustomTensorDataset(Dataset): 33 | """Dataset wrapping tensors, grouped by datetime. 34 | 35 | The dataset groups data by dates, reshapes it to 3D 36 | (`dates * stocks x time_ids x features`), stores it, 37 | and returns data for a single date. 38 | 39 | Args: 40 | X (ndarray or Tensor): Numerical features. 41 | resp (ndarray or Tensor): Auxiliary targets. 42 | y (ndarray or Tensor): Target variable. 43 | weights (ndarray or Tensor): Weights. 44 | symbols (ndarray or Tensor): Symbol IDs. 45 | dates (ndarray or Tensor): Date IDs. 46 | times (ndarray or Tensor): Time IDs. 47 | on_batch (bool): If True, data is reshaped when calling `__getitem__`. 48 | Required if the number of time IDs per date is not uniform. 49 | """ 50 | T = 968 51 | 52 | def __init__( 53 | self, 54 | X: np.array, 55 | resp: np.array, 56 | y: np.array, 57 | weights: np.array, 58 | symbols: np.array, 59 | dates: np.array, 60 | times: np.array, 61 | on_batch: bool = True, 62 | ): 63 | """ 64 | Initialize the CustomTensorDataset. 65 | 66 | Args: 67 | X (ndarray or Tensor): Numerical features. 68 | resp (ndarray or Tensor): Auxiliary targets. 69 | y (ndarray or Tensor): Target variable. 70 | weights (ndarray or Tensor): Weights. 71 | symbols (ndarray or Tensor): Symbol IDs. 72 | dates (ndarray or Tensor): Date IDs. 73 | times (ndarray or Tensor): Time IDs. 74 | on_batch (bool): If True, data is reshaped when calling `__getitem__`. 75 | Required if the number of time IDs per date is not uniform. 
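Note: when `on_batch` is False, every date is assumed to contain exactly `T = 968` time steps so that the data can be reshaped to 3D up front.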
76 | """ 77 | self.on_batch = on_batch 78 | self.num_features = X.shape[1] 79 | 80 | self.X = torch.tensor(X, dtype=torch.float32) 81 | self.resp = torch.tensor(resp, dtype=torch.float32) 82 | self.y = torch.tensor(y, dtype=torch.float32) 83 | self.weights = torch.tensor(weights, dtype=torch.float32) 84 | self.symbols = torch.tensor(symbols, dtype=torch.int64) 85 | self.dates = torch.tensor(dates, dtype=torch.int64) 86 | self.times = torch.tensor(times, dtype=torch.int64) 87 | 88 | self.X = torch.nan_to_num(self.X, 0) 89 | 90 | self.K = X.shape[1] 91 | 92 | if not self.on_batch: 93 | T = self.T 94 | N, K = self.X.shape 95 | 96 | sorted_indices = torch.argsort(self.times, stable=True) 97 | sorted_indices = sorted_indices[torch.argsort(self.dates[sorted_indices], stable=True)] 98 | sorted_indices = sorted_indices[torch.argsort(self.symbols[sorted_indices], stable=True)] 99 | self.X = self.X[sorted_indices] 100 | self.resp = self.resp[sorted_indices] 101 | self.dates = self.dates[sorted_indices] 102 | self.y = self.y[sorted_indices] 103 | self.weights = self.weights[sorted_indices] 104 | self.symbols = self.symbols[sorted_indices] 105 | 106 | self.X = self.X.view(N//T, T, K) 107 | self.resp = self.resp.view(N//T, T, self.resp.shape[-1]) 108 | self.dates = self.dates.view(N//T, T)[:,0].squeeze() 109 | self.y = self.y.view(N//T, T) 110 | self.weights = self.weights.view(N//T, T) 111 | self.symbols = self.symbols.view(N//T, T) 112 | 113 | 114 | self.datetime_ids = self.dates 115 | self.unique_datetimes, self.inverse_indices, self.counts = torch.unique( 116 | self.datetime_ids, return_inverse=True, return_counts=True 117 | ) 118 | 119 | self.sorted_indices = torch.argsort(self.inverse_indices) 120 | self.group_end_indices = torch.cumsum(self.counts, dim=0) 121 | self.group_start_indices = torch.cat((torch.tensor([0]), self.group_end_indices[:-1])) 122 | 123 | def __getitem__(self, index: int) -> tuple[torch.Tensor]: 124 | """ 125 | Get the data for a specific index (date id). 126 | 127 | Args: 128 | index (int): Index of the date. 129 | 130 | Returns: 131 | tuple[torch.Tensor]: A tuple containing X, y, resp, and weights for the specified index. 132 | """ 133 | start = self.group_start_indices[index] 134 | end = self.group_end_indices[index] 135 | index = self.sorted_indices[start:end] 136 | 137 | X = self.X[index] 138 | resp = self.resp[index] 139 | y = self.y[index] 140 | weights = self.weights[index] 141 | 142 | if self.on_batch: 143 | T = max(self.times[index])+1 144 | X = X.reshape(T, -1, self.K).swapaxes(0, 1) 145 | resp = resp.reshape(T, -1, resp.shape[1]).swapaxes(0, 1) 146 | y = y.reshape(T, -1).swapaxes(0, 1) 147 | weights = weights.reshape(T, -1).swapaxes(0, 1) 148 | 149 | return X, resp, y, weights 150 | 151 | def __len__(self) -> int: 152 | """Get the length of the dataset.""" 153 | return len(self.unique_datetimes) 154 | 155 | 156 | class ModelRBase(nn.Module): 157 | """Base recurrent model. 158 | 159 | This class defines a recurrent neural network with GRU or LSTM layers, 160 | followed by fully connected (linear) layers. Dropout can be applied 161 | after each recurrent and linear layer. 162 | 163 | Args: 164 | input_size (int): Number of input features. 165 | hidden_sizes (list): List of hidden sizes for the recurrent layers. 166 | dropout_rates (list): List of dropout rates for the recurrent layers. 167 | hidden_sizes_linear (list): List of hidden sizes for the linear layers. 168 | dropout_rates_linear (list): List of dropout rates for the linear layers. 
169 | model_type (str): Type of the model, either "gru" or "lstm". 170 | 171 | Raises: 172 | ValueError: If `model_type` is not "gru" or "lstm". 173 | """ 174 | def __init__( 175 | self, 176 | input_size: int, 177 | hidden_sizes: list, 178 | dropout_rates: list, 179 | hidden_sizes_linear: list, 180 | dropout_rates_linear: list, 181 | model_type: str 182 | ) -> None: 183 | """ 184 | Initializes the ModelRBase class. 185 | 186 | Args: 187 | input_size (int): Number of input features. 188 | hidden_sizes (list): List of hidden sizes for the recurrent layers. 189 | dropout_rates (list): List of dropout rates for the recurrent layers. 190 | hidden_sizes_linear (list): List of hidden sizes for the linear layers. 191 | dropout_rates_linear (list): List of dropout rates for the linear layers. 192 | model_type (str): Type of the model, either "gru" or "lstm". 193 | 194 | Raises: 195 | ValueError: If `model_type` is not "gru" or "lstm". 196 | """ 197 | super(ModelRBase, self).__init__() 198 | self.num_layers = len(hidden_sizes) 199 | 200 | self.gru_layers = nn.ModuleList() 201 | self.dropout_rates = nn.ModuleList() 202 | for i in range(self.num_layers): 203 | input_dim = input_size if i == 0 else hidden_sizes[i - 1] 204 | if model_type == "gru": 205 | layer = nn.GRU(input_dim, hidden_sizes[i], num_layers=1, batch_first=True) 206 | elif model_type == "lstm": 207 | layer = nn.LSTM(input_dim, hidden_sizes[i], num_layers=1, batch_first=True) 208 | else: 209 | raise ValueError("Unknown model type") 210 | self.gru_layers.append(layer) 211 | self.dropout_rates.append(nn.Dropout(dropout_rates[i])) 212 | 213 | if self.num_layers == 0: 214 | n_input_linear = input_size 215 | else: 216 | n_input_linear = hidden_sizes[-1] 217 | 218 | fc_layers = [] 219 | if hidden_sizes_linear: 220 | for i in range(len(hidden_sizes_linear)): 221 | in_features = n_input_linear if i == 0 else hidden_sizes_linear[i - 1] 222 | fc_layers.append(nn.Linear(in_features, hidden_sizes_linear[i])) 223 | fc_layers.append(nn.ReLU()) 224 | fc_layers.append(nn.Dropout(dropout_rates_linear[i])) 225 | fc_layers.append(nn.Linear(hidden_sizes_linear[-1], 1)) 226 | else: 227 | fc_layers.append(nn.Linear(n_input_linear, 1)) 228 | 229 | self.fc = nn.Sequential(*fc_layers) 230 | 231 | def forward(self, x: torch.Tensor, hidden: bool = None) -> tuple[torch.Tensor, torch.Tensor]: 232 | """ 233 | Forward pass of the model. 234 | 235 | Args: 236 | x (torch.Tensor): Input tensor of shape (D, T, input_size), 237 | where D is the batch size, T is the sequence length, 238 | and `input_size` is the number of features. 239 | hidden (bool, optional): Initial hidden state for the recurrent layers. 240 | Defaults to None. 241 | 242 | Returns: 243 | tuple[torch.Tensor, torch.Tensor]: A tuple containing: 244 | - Output tensor of shape (D, T), where D is the batch size and 245 | T is the sequence length. 246 | - Hidden state tensor from the last recurrent layer. 247 | """ 248 | D, T, _ = x.shape 249 | 250 | if hidden is None: 251 | hidden = [None] * self.num_layers 252 | 253 | for i, gru in enumerate(self.gru_layers): 254 | x, h = gru(x, hidden[i]) 255 | if hasattr(self, "dropout_rates"): 256 | x = self.dropout_rates[i](x) 257 | hidden[i] = h 258 | 259 | x = x.reshape(D * T, -1) 260 | x = self.fc(x) 261 | x = x.reshape(D, T) 262 | 263 | return x, hidden 264 | 265 | 266 | class ModelR(nn.Module): 267 | """Recurrent model with auxiliary targets. 
268 | 269 | This model uses multiple recurrent networks (GRU or LSTM) to predict both 270 | the primary target and auxiliary targets. The auxiliary targets are 271 | combined using a fully connected layer to produce the final output. 272 | 273 | Args: 274 | input_size (int): Number of input features. 275 | hidden_sizes (list): List of hidden sizes for the recurrent layers. 276 | dropout_rates (list): List of dropout rates for the recurrent layers. 277 | hidden_sizes_linear (list): List of hidden sizes for the linear layers. 278 | dropout_rates_linear (list): List of dropout rates for the linear layers. 279 | model_type (str): Type of the model, either "gru" or "lstm". 280 | """ 281 | def __init__( 282 | self, 283 | input_size: int, 284 | hidden_sizes: list, 285 | dropout_rates: list, 286 | hidden_sizes_linear: list, 287 | dropout_rates_linear: list, 288 | model_type: str 289 | ): 290 | """ 291 | Initializes the ModelR class. 292 | 293 | Args: 294 | input_size (int): Number of input features. 295 | hidden_sizes (list): List of hidden sizes for the recurrent layers. 296 | dropout_rates (list): List of dropout rates for the recurrent layers. 297 | hidden_sizes_linear (list): List of hidden sizes for the linear layers. 298 | dropout_rates_linear (list): List of dropout rates for the linear layers. 299 | model_type (str): Type of the model, either "gru" or "lstm". 300 | """ 301 | super(ModelR, self).__init__() 302 | self.num_resp = 4 303 | 304 | self.grus = nn.ModuleList() 305 | self.fcs = nn.ModuleList() 306 | for _ in range(self.num_resp): 307 | self.grus.append( 308 | ModelRBase( 309 | input_size, 310 | hidden_sizes, 311 | dropout_rates, 312 | hidden_sizes_linear, 313 | dropout_rates_linear, 314 | model_type 315 | ) 316 | ) 317 | self.out = nn.Sequential( 318 | nn.Linear(self.num_resp, 1), 319 | ) 320 | 321 | def forward( 322 | self, 323 | x: torch.Tensor, 324 | hidden: torch.Tensor | None = None 325 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 326 | """ 327 | Forward pass of the model. 328 | 329 | Args: 330 | x (torch.Tensor): Input tensor of shape (D, T, input_size), 331 | where D is the batch size, T is the sequence length, 332 | and `input_size` is the number of features. 333 | hidden (torch.Tensor or None, optional): Initial hidden state for 334 | the recurrent layers. Defaults to None. 335 | 336 | Returns: 337 | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: 338 | - Output tensor (torch.Tensor): Final output tensor of shape (D, T). 339 | - Auxiliary targets (torch.Tensor): Tensor of shape (D, T, num_resp), 340 | where `num_resp` is the number of auxiliary targets. 341 | - Hidden state (torch.Tensor): Hidden state from the last recurrent layer. 342 | """ 343 | D, T, _ = x.shape 344 | 345 | if hidden is None: 346 | hidden = [None] * self.num_resp 347 | 348 | out = [] 349 | for i in range(len(self.grus)): 350 | z, h = self.grus[i](x, hidden[i]) 351 | out.append(z) 352 | out[i] = out[i].reshape(D*T, -1) 353 | hidden[i] = h 354 | out_resp = torch.cat(out, dim=-1) 355 | y = self.out(out_resp) 356 | 357 | out_resp = out_resp.reshape(D, T, -1) 358 | y = y.reshape(D, T) 359 | 360 | return y, out_resp, hidden 361 | 362 | class NN: 363 | """Neural network model for time series data with auxiliary targets. 364 | 365 | This class defines a recurrent neural network (GRU or LSTM) with 366 | optional linear layers and dropout. It includes functionality for 367 | training, validation, updating with new data, and making predictions. 
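The training loss is the sum of weighted R² losses over the main target and the four auxiliary responders; during validation, one refit step per date can be applied (controlled by `lr_refit`) to mirror the online-learning setup used at inference.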
368 | 369 | Args: 370 | model_type (str, optional): Type of recurrent model, either "gru" or "lstm". 371 | hidden_sizes (list, optional): List of hidden sizes for the recurrent layers. 372 | dropout_rates (list, optional): List of dropout rates for the recurrent layers. 373 | hidden_sizes_linear (list, optional): List of hidden sizes for the linear layers. 374 | dropout_rates_linear (list, optional): List of dropout rates for the linear layers. 375 | lr (float, optional): Learning rate. Defaults to 0.001. 376 | batch_size (int, optional): Batch size for training. Defaults to 1. 377 | epochs (int, optional): Number of epochs for training. Defaults to 100. 378 | early_stopping_patience (int, optional): Patience for early stopping. Defaults to 10. 379 | early_stopping (bool, optional): Whether to use early stopping. Defaults to True. 380 | lr_patience (int, optional): Patience for learning rate reduction. Defaults to 2. 381 | lr_factor (float, optional): Factor for reducing learning rate. Defaults to 0.5. 382 | lr_refit (float, optional): Learning rate for model refitting. Defaults to 0.001. 383 | random_seed (int, optional): Random seed for reproducibility. Defaults to 42. 384 | """ 385 | def __init__( 386 | self, 387 | model_type: str | None = None, 388 | hidden_sizes: list | None = None, 389 | dropout_rates: list | None = None, 390 | hidden_sizes_linear: list | None = None, 391 | dropout_rates_linear: list | None = None, 392 | lr: float = 0.001, 393 | batch_size: int = 1, 394 | epochs: int = 100, 395 | early_stopping_patience: int = 10, 396 | early_stopping: bool = True, 397 | lr_patience: int = 2, 398 | lr_factor: float = 0.5, 399 | lr_refit: float = 0.001, 400 | random_seed: int = 42, 401 | ) -> None: 402 | """Initialize the NN model. 403 | 404 | Args: 405 | model_type (str, optional): Type of recurrent model, either "gru" or "lstm". 406 | hidden_sizes (list, optional): List of hidden sizes for the recurrent layers. 407 | dropout_rates (list, optional): List of dropout rates for the recurrent layers. 408 | hidden_sizes_linear (list, optional): List of hidden sizes for the linear layers. 409 | dropout_rates_linear (list, optional): List of dropout rates for the linear layers. 410 | lr (float, optional): Learning rate. Defaults to 0.001. 411 | batch_size (int, optional): Batch size for training. Defaults to 1. 412 | epochs (int, optional): Number of epochs for training. Defaults to 100. 413 | early_stopping_patience (int, optional): Patience for early stopping. Defaults to 10. 414 | early_stopping (bool, optional): Whether to use early stopping. Defaults to True. 415 | lr_patience (int, optional): Patience for learning rate reduction. Defaults to 2. 416 | lr_factor (float, optional): Factor for reducing learning rate. Defaults to 0.5. 417 | lr_refit (float, optional): Learning rate for model refitting. Defaults to 0.001. 418 | random_seed (int, optional): Random seed for reproducibility. Defaults to 42. 
419 | """ 420 | self.model_type = model_type 421 | self.hidden_sizes = hidden_sizes 422 | self.dropout_rates = dropout_rates 423 | self.hidden_sizes_linear = hidden_sizes_linear 424 | self.dropout_rates_linear = dropout_rates_linear 425 | self.lr = lr 426 | self.batch_size = batch_size 427 | self.epochs = epochs 428 | self.early_stopping_patience = early_stopping_patience 429 | self.early_stopping = early_stopping 430 | self.lr_patience = lr_patience 431 | self.lr_factor = lr_factor 432 | self.lr_refit = lr_refit 433 | self.random_seed = random_seed 434 | 435 | self.criterion = WeightedR2Loss() 436 | 437 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 438 | self.model = None 439 | self.optimizer = None 440 | self.best_epoch = None 441 | self.features = None 442 | 443 | def fit(self, train_set: tuple, val_set: tuple, verbose: bool = False) -> None: 444 | """Fit the model on the training set and validate on the validation set. 445 | 446 | Args: 447 | train_set (tuple): A tuple containing input data, targets, and weights for training. 448 | val_set (tuple): A tuple containing input data, targets, and weights for validation. 449 | verbose (bool, optional): If True, prints training progress. Defaults to False. 450 | """ 451 | torch.manual_seed(self.random_seed) 452 | 453 | train_dataset = CustomTensorDataset(*train_set, on_batch=False) 454 | train_dataloader = DataLoader( 455 | train_dataset, 456 | batch_size=self.batch_size, 457 | shuffle=True, 458 | collate_fn=flatten_collate_fn 459 | ) 460 | 461 | val_dataset = CustomTensorDataset(*val_set, on_batch=True) 462 | val_dataloader = DataLoader( 463 | val_dataset, 464 | batch_size=1, 465 | shuffle=False, 466 | collate_fn=flatten_collate_fn 467 | ) 468 | 469 | self.model = ModelR( 470 | train_dataset.num_features, 471 | self.hidden_sizes, 472 | self.dropout_rates, 473 | self.hidden_sizes_linear, 474 | self.dropout_rates_linear, 475 | self.model_type 476 | ).to(self.device) 477 | 478 | self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=0.01) 479 | 480 | train_r2s, val_r2s = [], [] 481 | if verbose: 482 | print(f"Device: {self.device}") 483 | print( 484 | f"{'Epoch':^5} | {'Train Loss':^10} | {'Val Loss':^8} " 485 | f"| {'Train R2':^9} | {'Val R2':^7} | {'LR':^7}" 486 | ) 487 | print("-" * 60) 488 | 489 | min_val_r2 = -np.inf 490 | best_epoch = 0 491 | no_improvement = 0 492 | best_model = None 493 | for epoch in range(self.epochs): 494 | train_loss, train_r2 = self.train_one_epoch(train_dataloader, verbose) 495 | val_loss, val_r2 = self.validate_one_epoch(val_dataloader, verbose) 496 | lr_last = self.optimizer.param_groups[0]["lr"] 497 | 498 | train_r2s.append(train_r2) 499 | val_r2s.append(val_r2) 500 | 501 | if verbose: 502 | print( 503 | f"{epoch+1:^5} | {train_loss:^10.4f} | {val_loss:^8.4f} | " 504 | f"{train_r2:^9.4f} | {val_r2:^7.4f} | {lr_last:^7.5f}" 505 | ) 506 | 507 | if val_r2 > min_val_r2: 508 | min_val_r2 = val_r2 509 | best_model = copy.deepcopy(self.model.state_dict()) 510 | no_improvement = 0 511 | best_epoch = epoch 512 | else: 513 | no_improvement += 1 514 | 515 | if self.early_stopping: 516 | if no_improvement >= self.early_stopping_patience + 1: 517 | self.best_epoch = best_epoch+1 518 | if verbose: 519 | print( 520 | f"Early stopping on epoch {best_epoch+1}. 
" 521 | f"Best score: {min_val_r2:.4f}" 522 | ) 523 | break 524 | 525 | # Load the best model 526 | if self.early_stopping: 527 | self.model.load_state_dict(best_model) 528 | 529 | 530 | def train_one_epoch(self, train_dataloader: DataLoader, verbose: bool) -> None: 531 | """Train the model for one epoch. 532 | 533 | Args: 534 | train_dataloader (DataLoader): DataLoader for the training set. 535 | verbose (bool): If True, shows progress using tqdm. 536 | 537 | Returns: 538 | tuple[float, float]: A tuple containing: 539 | - Train loss (float). 540 | - Weighted R² score for the training set (float). 541 | """ 542 | self.model.train() 543 | total_loss = 0.0 544 | 545 | y_total, weights_total, preds_total = [], [], [] 546 | if verbose: 547 | itr = tqdm(train_dataloader) 548 | else: 549 | itr = train_dataloader 550 | 551 | for x_batch, resp_batch, y_batch, weights_batch in itr: 552 | x_batch, resp_batch, y_batch, weights_batch = ( 553 | item.to(self.device) 554 | for item in [x_batch, resp_batch, y_batch, weights_batch] 555 | ) 556 | 557 | self.optimizer.zero_grad() 558 | out_y, out_resp, _ = self.model(x_batch, None) 559 | loss1 = self.criterion(out_y.flatten(), y_batch.flatten(), weights_batch.flatten()) 560 | loss2 = self.criterion( 561 | out_resp[:, :, 0].flatten(), 562 | resp_batch[:, :, -1].flatten(), 563 | weights_batch.flatten() 564 | ) 565 | loss3 = self.criterion( 566 | out_resp[:, :, 1].flatten(), 567 | resp_batch[:, :, -2].flatten(), 568 | weights_batch.flatten() 569 | ) 570 | loss4 = self.criterion( 571 | out_resp[:, :, 2].flatten(), 572 | resp_batch[:, :, -3].flatten(), 573 | weights_batch.flatten() 574 | ) 575 | loss5 = self.criterion( 576 | out_resp[:, :, 3].flatten(), 577 | resp_batch[:, :, -4].flatten(), 578 | weights_batch.flatten() 579 | ) 580 | loss = loss1+loss2+loss3+loss4+loss5 581 | loss.backward() 582 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) 583 | 584 | self.optimizer.step() 585 | 586 | total_loss += loss.item() 587 | 588 | y_total.append(y_batch.flatten()) 589 | weights_total.append(weights_batch.flatten()) 590 | preds_total.append(out_y.detach().flatten()) 591 | 592 | y_total = torch.cat(y_total).cpu() 593 | weights_total = torch.cat(weights_total).cpu() 594 | preds_total = torch.cat(preds_total).cpu() 595 | 596 | train_r2 = r2_weighted_torch(y_total, preds_total, weights_total).item() 597 | train_loss = total_loss / len(train_dataloader) 598 | 599 | return train_loss, train_r2 600 | 601 | 602 | def validate_one_epoch(self, val_dataloader: DataLoader, verbose=False) -> None: 603 | """Validate the model on the validation set. 604 | 605 | Args: 606 | val_dataloader (DataLoader): DataLoader for the validation set. 607 | verbose (bool, optional): If True, shows progress using tqdm. Defaults to False. 608 | 609 | Returns: 610 | tuple[float, float]: A tuple containing: 611 | - Validation loss (float). 612 | - Weighted R² score for the validation set (float). 
613 | """ 614 | model = copy.deepcopy(self.model) 615 | 616 | losses, all_y, all_weights, all_preds = [], [], [], [] 617 | 618 | if verbose: 619 | itr = tqdm(val_dataloader) 620 | else: 621 | itr = val_dataloader 622 | for x_batch, resp_batch, y_batch, weights_batch in itr: 623 | x_batch, resp_batch, y_batch, weights_batch = ( 624 | item.to(self.device) 625 | for item in [x_batch, resp_batch, y_batch, weights_batch] 626 | ) 627 | 628 | # Predict 629 | with torch.no_grad(): 630 | model.eval() 631 | preds_batch, _, _ = model(x_batch, None) 632 | loss = self.criterion( 633 | preds_batch.flatten(), 634 | y_batch.flatten(), 635 | weights_batch.flatten() 636 | ) 637 | losses.append(loss.item()) 638 | 639 | all_y.append(y_batch.flatten()) 640 | all_weights.append(weights_batch.flatten()) 641 | all_preds.append(preds_batch.flatten()) 642 | 643 | # Update weights 644 | if self.lr_refit > 0: 645 | optimizer = torch.optim.AdamW( 646 | model.parameters(), 647 | lr=self.lr_refit, 648 | weight_decay=0.01 649 | ) 650 | optimizer.zero_grad() 651 | model.train() 652 | out_y, _, _ = model(x_batch, None) 653 | loss = self.criterion(out_y, y_batch, weights_batch) 654 | loss.backward() 655 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) 656 | optimizer.step() 657 | 658 | all_y = torch.cat(all_y) 659 | all_weights = torch.cat(all_weights) 660 | all_preds = torch.cat(all_preds) 661 | loss = np.mean(losses) 662 | r2 = r2_weighted_torch(all_y, all_preds, all_weights).item() 663 | 664 | return loss, r2 665 | 666 | def update( 667 | self, 668 | X: np.array, 669 | y: np.array, 670 | weights: np.array, 671 | n_times: int, 672 | ): 673 | """Update the model with new data. 674 | 675 | Args: 676 | X (np.array): Input data. 677 | y (np.array): Target variable. 678 | weights (np.array): Weights for the target variable. 679 | n_times (int): Number of time steps. 680 | """ 681 | if self.lr_refit == 0.0: 682 | return 683 | 684 | X = torch.tensor(X, dtype=torch.float32) 685 | y = torch.tensor(y, dtype=torch.float32) 686 | weights = torch.tensor(weights, dtype=torch.float32) 687 | 688 | N, K = X.shape 689 | X = X.view(n_times, N//n_times, K).swapaxes(0, 1).to(self.device) 690 | y = y.view(n_times, N//n_times).swapaxes(0, 1).to(self.device) 691 | weights = weights.view(n_times, N//n_times).swapaxes(0, 1).to(self.device) 692 | 693 | self.model.train() 694 | 695 | optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr_refit, weight_decay=0.01) 696 | optimizer.zero_grad() 697 | 698 | out_y, _, _ = self.model(X, None) 699 | loss = self.criterion(out_y.flatten(), y.flatten(), weights.flatten()) 700 | loss.backward() 701 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) 702 | optimizer.step() 703 | 704 | def predict( 705 | self, 706 | X: np.array, 707 | n_times: int = None, 708 | hidden: torch.Tensor | list | None = None 709 | ) -> tuple[np.array, torch.Tensor | list]: 710 | """Predict the target variable for the given input data. 711 | 712 | Args: 713 | X (np.array): Input data. 714 | n_times (int, optional): Number of time steps. Defaults to None. 715 | hidden (torch.Tensor or list or None, optional): Initial hidden state. Defaults to None. 716 | 717 | Returns: 718 | tuple[np.array, torch.Tensor or list]: A tuple containing: 719 | - Predictions (np.array). 720 | - Hidden state (torch.Tensor or list). 
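Note: the returned hidden state can be passed back in on the next call so that the recurrent state carries over across consecutive predictions.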
721 | """ 722 | X_tensor = torch.tensor(X, dtype=torch.float32) 723 | 724 | N, K = X.shape 725 | X_tensor = X_tensor.view(n_times, N//n_times, K).swapaxes(0, 1).to(self.device) 726 | 727 | X_tensor = torch.nan_to_num(X_tensor, 0) 728 | self.model.eval() 729 | with torch.no_grad(): 730 | preds, _, hidden = self.model(X_tensor, hidden) 731 | preds = preds.swapaxes(0, 1) 732 | preds = preds.reshape(-1).cpu().numpy() 733 | return preds, hidden 734 | 735 | def get_params(self, deep: bool = True): 736 | """Get parameters for this estimator. 737 | 738 | Args: 739 | deep (bool): If True, will return the parameters for this 740 | estimator and contained subobjects that are estimators. 741 | """ 742 | return { 743 | "model_type": self.model_type, 744 | "hidden_sizes": self.hidden_sizes, 745 | "dropout_rates": self.dropout_rates, 746 | "hidden_sizes_linear": self.hidden_sizes_linear, 747 | "dropout_rates_linear": self.dropout_rates_linear, 748 | "lr": self.lr, 749 | "batch_size": self.batch_size, 750 | "epochs": self.epochs, 751 | "early_stopping_patience": self.early_stopping_patience, 752 | "early_stopping": self.early_stopping, 753 | "lr_patience": self.lr_patience, 754 | "lr_factor": self.lr_factor, 755 | "lr_refit": self.lr_refit, 756 | "random_seed": self.random_seed 757 | } 758 | 759 | def set_params(self, **parameters): 760 | """Set the parameters of this estimator. 761 | 762 | Args: 763 | parameters (dict): A dictionary of parameters to set, mapping parameter 764 | names to their new values. 765 | """ 766 | for parameter, value in parameters.items(): 767 | setattr(self, parameter, value) 768 | return self 769 | --------------------------------------------------------------------------------